浏览代码

t.rast.what: added new r.what v flag; added test cases

git-svn-id: https://svn.osgeo.org/grass/grass/trunk@71105 15284696-431f-4ddb-bdfa-cd5b030d7da7
Luca Delucchi 8 年之前
父节点
当前提交
8c5ecea1c8
共有 2 个文件被更改,包括 129 次插入55 次删除
  1. 84 55
      temporal/t.rast.what/t.rast.what.py
  2. 45 0
      temporal/t.rast.what/testsuite/test_what.py

+ 84 - 55
temporal/t.rast.what/t.rast.what.py

@@ -109,10 +109,16 @@
 ##% description: Output integer category values, not cell values
 ##% description: Output integer category values, not cell values
 ##%end
 ##%end
 
 
+#%flag
+#% key: v
+#% description: Show the category for vector points map
+#%end
+
 import sys
 import sys
 import copy
 import copy
 import grass.script as gscript
 import grass.script as gscript
 
 
+import pdb
 
 
 ############################################################################
 ############################################################################
 
 
@@ -130,11 +136,12 @@ def main(options, flags):
     order = options["order"]
     order = options["order"]
     layout = options["layout"]
     layout = options["layout"]
     null_value = options["null_value"]
     null_value = options["null_value"]
-    separator = options["separator"]
+    separator = gscript.separator(options["separator"])
     
     
     nprocs = int(options["nprocs"])
     nprocs = int(options["nprocs"])
     write_header = flags["n"]
     write_header = flags["n"]
     use_stdin = flags["i"]
     use_stdin = flags["i"]
+    vcat = flags["v"]
 
 
     #output_cat_label = flags["f"]
     #output_cat_label = flags["f"]
     #output_color = flags["r"]
     #output_color = flags["r"]
@@ -148,6 +155,9 @@ def main(options, flags):
     if not coordinates and not points and not use_stdin: 
     if not coordinates and not points and not use_stdin: 
         gscript.fatal(_("Please specify the coordinates, the points option or use the 'i' flag to pipe coordinate positions to t.rast.what from stdin, to provide the sampling coordinates"))
         gscript.fatal(_("Please specify the coordinates, the points option or use the 'i' flag to pipe coordinate positions to t.rast.what from stdin, to provide the sampling coordinates"))
 
 
+    if vcat and not points:
+        gscript.fatal(_("Flag 'v' required option 'points'"))
+
     if use_stdin:
     if use_stdin:
         coordinates_stdin = str(sys.__stdin__.read())
         coordinates_stdin = str(sys.__stdin__.read())
         # Check if coordinates are given with site names or IDs
         # Check if coordinates are given with site names or IDs
@@ -169,22 +179,9 @@ def main(options, flags):
     maps = sp.get_registered_maps_as_objects(where=where, order=order, 
     maps = sp.get_registered_maps_as_objects(where=where, order=order, 
                                              dbif=dbif)
                                              dbif=dbif)
     dbif.close()
     dbif.close()
-
     if not maps:
     if not maps:
         gscript.fatal(_("Space time raster dataset <%s> is empty") % sp.get_id())
         gscript.fatal(_("Space time raster dataset <%s> is empty") % sp.get_id())
 
 
-    # Setup separator
-    if separator == "pipe":
-        separator = "|"
-    if separator == "comma":
-        separator = ","
-    if separator == "space":
-        separator = " "
-    if separator == "tab":
-        separator = "\t"
-    if separator == "newline":
-        separator = "\n"
-
     # Setup flags are disabled due to test issues
     # Setup flags are disabled due to test issues
     flags = ""
     flags = ""
     #if output_cat_label is True:
     #if output_cat_label is True:
@@ -193,6 +190,8 @@ def main(options, flags):
     #    flags += "r"
     #    flags += "r"
     #if output_cat is True:
     #if output_cat is True:
     #    flags += "i"
     #    flags += "i"
+    if vcat is True:
+        flags += "v"
 
 
     # Configure the r.what module
     # Configure the r.what module
     if points: 
     if points: 
@@ -286,18 +285,18 @@ def main(options, flags):
     # Out the output files in the correct order together
     # Out the output files in the correct order together
     if layout == "row":
     if layout == "row":
         one_point_per_row_output(separator, output_files, output_time_list,
         one_point_per_row_output(separator, output_files, output_time_list,
-                                 output, write_header, site_input)
+                                 output, write_header, site_input, vcat)
     elif layout == "col":
     elif layout == "col":
         one_point_per_col_output(separator, output_files, output_time_list,
         one_point_per_col_output(separator, output_files, output_time_list,
-                                 output, write_header, site_input)
+                                 output, write_header, site_input, vcat)
     else:
     else:
         one_point_per_timerow_output(separator, output_files, output_time_list,
         one_point_per_timerow_output(separator, output_files, output_time_list,
-                                     output, write_header, site_input)
+                                     output, write_header, site_input, vcat)
 
 
 ############################################################################
 ############################################################################
 
 
 def one_point_per_row_output(separator, output_files, output_time_list,
 def one_point_per_row_output(separator, output_files, output_time_list,
-                             output, write_header, site_input):
+                             output, write_header, site_input, vcat):
     """Write one point per row
     """Write one point per row
        output is of type: x,y,start,end,value
        output is of type: x,y,start,end,value
     """
     """
@@ -305,13 +304,15 @@ def one_point_per_row_output(separator, output_files, output_time_list,
     out_file = open(output, 'w') if output != "-" else sys.stdout
     out_file = open(output, 'w') if output != "-" else sys.stdout
     
     
     if write_header is True:
     if write_header is True:
+        out_str = ""
+        if vcat:
+            out_str += "cat{sep}"
         if site_input:
         if site_input:
-            out_file.write("x%(sep)sy%(sep)ssite%(sep)sstart%(sep)send%(sep)svalue\n"\
-                       %({"sep":separator}))
+            out_str += "x{sep}y{sep}site{sep}start{sep}end{sep}value\n"
         else:
         else:
-            out_file.write("x%(sep)sy%(sep)sstart%(sep)send%(sep)svalue\n"\
-                       %({"sep":separator}))
-
+            out_str += "x{sep}y{sep}start{sep}end{sep}value\n"
+        out_file.write(out_str.format(sep=separator))
+    
     for count in range(len(output_files)):
     for count in range(len(output_files)):
         file_name = output_files[count]
         file_name = output_files[count]
         gscript.verbose(_("Transforming r.what output file %s"%(file_name)))
         gscript.verbose(_("Transforming r.what output file %s"%(file_name)))
@@ -319,15 +320,28 @@ def one_point_per_row_output(separator, output_files, output_time_list,
         in_file = open(file_name, "r")
         in_file = open(file_name, "r")
         for line in in_file:
         for line in in_file:
             line = line.split(separator)
             line = line.split(separator)
-            x = line[0]
-            y = line[1]
-            if site_input:
-                site = line[2]
+            if vcat:
+                cat = line[0]
+                x = line[1]
+                y = line[2]
+                values = line[4:]
+                if site_input:
+                    site = line[3]
+                    values = line[5:]
+                
+            else:
+                x = line[0]
+                y = line[1]
+                if site_input:
+                    site = line[2]
+                values = line[3:]
 
 
-            # We ignore the site name
-            values = line[3:]
             for i in range(len(values)):
             for i in range(len(values)):
                 start, end = map_list[i].get_temporal_extent_as_tuple()
                 start, end = map_list[i].get_temporal_extent_as_tuple()
+                if vcat:
+                    cat_str = "{ca}{sep}".format(ca=cat, sep=separator)
+                else:
+                    cat_str = ""
                 if site_input:
                 if site_input:
                     coor_string = "%(x)10.10f%(sep)s%(y)10.10f%(sep)s%(site_name)s%(sep)s"\
                     coor_string = "%(x)10.10f%(sep)s%(y)10.10f%(sep)s%(site_name)s%(sep)s"\
                                %({"x":float(x),"y":float(y),"site_name":str(site),"sep":separator})
                                %({"x":float(x),"y":float(y),"site_name":str(site),"sep":separator})
@@ -338,7 +352,7 @@ def one_point_per_row_output(separator, output_files, output_time_list,
                                %({"start":str(start), "end":str(end),
                                %({"start":str(start), "end":str(end),
                                   "val":(values[i].strip()),"sep":separator})
                                   "val":(values[i].strip()),"sep":separator})
 
 
-                out_file.write(coor_string + time_string)
+                out_file.write(cat_str + coor_string + time_string)
         
         
         in_file.close()
         in_file.close()
     
     
@@ -348,7 +362,7 @@ def one_point_per_row_output(separator, output_files, output_time_list,
 ############################################################################
 ############################################################################
 
 
 def one_point_per_col_output(separator, output_files, output_time_list,
 def one_point_per_col_output(separator, output_files, output_time_list,
-                             output, write_header, site_input):
+                             output, write_header, site_input, vcat):
     """Write one point per col
     """Write one point per col
        output is of type: 
        output is of type: 
        start,end,point_1 value,point_2 value,...,point_n value
        start,end,point_1 value,point_2 value,...,point_n value
@@ -374,38 +388,47 @@ def one_point_per_col_output(separator, output_files, output_time_list,
         
         
         if first is True:
         if first is True:
             if write_header is True:
             if write_header is True:
-                out_file.write("start%(sep)send"%({"sep":separator}))
-                if site_input:
-                    for row in matrix:
-                        x = row[0]
-                        y = row[1]
-                        site = row[2]
-                        out_file.write("%(sep)s%(x)10.10f;%(y)10.10f;%(site_name)s"\
-                                   %({"sep":separator,
-                                      "x":float(x), 
-                                      "y":float(y),
-                                      "site_name":str(site)}))
-                else:
-                    for row in matrix:
+                out_str = "start%(sep)send"%({"sep":separator})
+                for row in matrix:
+                    if vcat:
+                        cat = row[0]
+                        x = row[1]
+                        y = row[2]
+                        out_str += "{sep}{cat}{sep}{x:10.10f};" \
+                                  "{y:10.10f}".format(cat=cat, x=float(x),
+                                                           y=float(y),
+                                                           sep=separator)
+                        if site_input:
+                            site = row[3]
+                            out_str += "{sep}{site}".format(sep=separator,
+                                                            site=site)
+                    else:
                         x = row[0]
                         x = row[0]
                         y = row[1]
                         y = row[1]
-                        out_file.write("%(sep)s%(x)10.10f;%(y)10.10f"\
-                                   %({"sep":separator,
-                                      "x":float(x), 
-                                      "y":float(y)}))
+                        out_str += "{sep}{x:10.10f};" \
+                                   "{y:10.10f}".format(x=float(x), y=float(y),
+                                                       sep=separator)
+                        if site_input:
+                            site = row[2]
+                            out_str += "{sep}{site}".format(sep=separator,
+                                                            site=site)
 
 
-                out_file.write("\n")
+            out_file.write(out_str + "\n")
 
 
         first = False
         first = False
 
 
-        for col in range(num_cols - 3):
+        if vcat:
+            ncol = 4
+        else:
+            ncol = 3
+        for col in range(num_cols - ncol):
             start, end = output_time_list[count][col].get_temporal_extent_as_tuple()
             start, end = output_time_list[count][col].get_temporal_extent_as_tuple()
             time_string = "%(start)s%(sep)s%(end)s"\
             time_string = "%(start)s%(sep)s%(end)s"\
                                %({"start":str(start), "end":str(end),
                                %({"start":str(start), "end":str(end),
                                   "sep":separator})
                                   "sep":separator})
             out_file.write(time_string)
             out_file.write(time_string)
             for row in range(len(matrix)):
             for row in range(len(matrix)):
-                value = matrix[row][col + 3]
+                value = matrix[row][col + ncol]
                 out_file.write("%(sep)s%(value)s"\
                 out_file.write("%(sep)s%(value)s"\
                                    %({"sep":separator,
                                    %({"sep":separator,
                                       "value":value.strip()}))
                                       "value":value.strip()}))
@@ -418,7 +441,7 @@ def one_point_per_col_output(separator, output_files, output_time_list,
 ############################################################################
 ############################################################################
 
 
 def one_point_per_timerow_output(separator, output_files, output_time_list,
 def one_point_per_timerow_output(separator, output_files, output_time_list,
-                             output, write_header, site_input):
+                             output, write_header, site_input, vcat):
     """Use the original layout of the r.what output and print instead of 
     """Use the original layout of the r.what output and print instead of 
        the raster names, the time stamps as header
        the raster names, the time stamps as header
        
        
@@ -441,10 +464,14 @@ def one_point_per_timerow_output(separator, output_files, output_time_list,
 
 
         if write_header:
         if write_header:
             if first is True:
             if first is True:
+                if vcat:
+                    header = "cat{sep}".format(sep=separator)
+                else:
+                    header = ""
                 if site_input:
                 if site_input:
-                    header = "x%(sep)sy%(sep)ssite"%({"sep":separator})
+                    header += "x%(sep)sy%(sep)ssite"%({"sep":separator})
                 else:
                 else:
-                    header = "x%(sep)sy"%({"sep":separator})
+                    header += "x%(sep)sy"%({"sep":separator})
             for map in map_list:
             for map in map_list:
                 start, end = map.get_temporal_extent_as_tuple()
                 start, end = map.get_temporal_extent_as_tuple()
                 time_string = "%(sep)s%(start)s;%(end)s"\
                 time_string = "%(sep)s%(start)s;%(end)s"\
@@ -458,7 +485,9 @@ def one_point_per_timerow_output(separator, output_files, output_time_list,
             cols = lines[i].split(separator)
             cols = lines[i].split(separator)
 
 
             if first is True:
             if first is True:
-                if site_input:
+                if vcat and site_input:
+                    matrix.append(cols[:4])
+                elif vcat or site_input:
                     matrix.append(cols[:3])
                     matrix.append(cols[:3])
                 else:
                 else:
                     matrix.append(cols[:2])
                     matrix.append(cols[:2])

+ 45 - 0
temporal/t.rast.what/testsuite/test_what.py

@@ -69,6 +69,14 @@ class TestRasterWhat(TestCase):
         self.assertFileMd5("out_row_coords.txt",
         self.assertFileMd5("out_row_coords.txt",
                            "cd917ac4848786f1b944512eed1da5bc", text=True)
                            "cd917ac4848786f1b944512eed1da5bc", text=True)
 
 
+    def test_row_output_cat(self):
+        self.assertModule("t.rast.what", strds="A", output="out_row_cat.txt",
+                          points="points", flags="nv", layout="row",
+                          nprocs=1, overwrite=True, verbose=True)
+
+        self.assertFileMd5("out_row_cat.txt",
+                           "bddde54bf4f40c41a17305c3442980e5", text=True)
+
     def test_col_output(self):
     def test_col_output(self):
         self.assertModule("t.rast.what", strds="A", output="out_col.txt",
         self.assertModule("t.rast.what", strds="A", output="out_col.txt",
                           points="points", flags="n", layout="col",
                           points="points", flags="n", layout="col",
@@ -93,6 +101,22 @@ class TestRasterWhat(TestCase):
         self.assertFileMd5("out_timerow.txt",
         self.assertFileMd5("out_timerow.txt",
                            "129fe0b63019e505232efa20ad42c03a", text=True)
                            "129fe0b63019e505232efa20ad42c03a", text=True)
 
 
+    def test_col_output_cat(self):
+        self.assertModule("t.rast.what", strds="A", output="out_col_cat.txt",
+                          points="points", flags="nv", layout="col",
+                          nprocs=1, overwrite=True, verbose=True)
+
+        self.assertFileMd5("out_col_cat.txt",
+                           "e1d8e6651b3bff1c35e366e48d634db4", text=True)
+
+    def test_timerow_output_cat(self):
+        self.assertModule("t.rast.what", strds="A", output="out_col_trow.txt",
+                          points="points", flags="nv", layout="timerow",
+                          nprocs=1, overwrite=True, verbose=True)
+
+        self.assertFileMd5("out_col_trow.txt",
+                           "55e2ff8ddaeb731a73daca48adf2768d", text=True)
+
     def test_timerow_output_coords(self):
     def test_timerow_output_coords(self):
         self.assertModule("t.rast.what", strds="A", output="out_timerow_coords.txt",
         self.assertModule("t.rast.what", strds="A", output="out_timerow_coords.txt",
                           coordinates=(30, 30, 45, 45), flags="n", layout="timerow",
                           coordinates=(30, 30, 45, 45), flags="n", layout="timerow",
@@ -122,6 +146,27 @@ class TestRasterWhat(TestCase):
 """
 """
         self.assertLooksLike(text, t_rast_what.outputs.stdout)
         self.assertLooksLike(text, t_rast_what.outputs.stdout)
 
 
+    def test_row_stdout_where_parallel_cat(self):
+
+        t_rast_what = SimpleModule("t.rast.what", strds="A", output="-",
+                                   points="points", flags="nv",
+                                   where="start_time > '2001-03-01'",
+                                   nprocs=4, overwrite=True, verbose=True)
+        self.assertModule(t_rast_what)
+
+        text = """cat|x|y|start|end|value
+1|115.0043586274|36.3593955783|2001-04-01 00:00:00|2001-07-01 00:00:00|200
+2|79.6816763826|45.2391522853|2001-04-01 00:00:00|2001-07-01 00:00:00|200
+3|97.4892579600|79.2347263950|2001-04-01 00:00:00|2001-07-01 00:00:00|200
+1|115.0043586274|36.3593955783|2001-07-01 00:00:00|2001-10-01 00:00:00|300
+2|79.6816763826|45.2391522853|2001-07-01 00:00:00|2001-10-01 00:00:00|300
+3|97.4892579600|79.2347263950|2001-07-01 00:00:00|2001-10-01 00:00:00|300
+1|115.0043586274|36.3593955783|2001-10-01 00:00:00|2002-01-01 00:00:00|400
+2|79.6816763826|45.2391522853|2001-10-01 00:00:00|2002-01-01 00:00:00|400
+3|97.4892579600|79.2347263950|2001-10-01 00:00:00|2002-01-01 00:00:00|400
+"""
+        self.assertLooksLike(text, t_rast_what.outputs.stdout)
+
     def test_row_stdout_where_parallel2(self):
     def test_row_stdout_where_parallel2(self):
         """Here without output definition, the default is used then"""
         """Here without output definition, the default is used then"""