ソースを参照

script.utils: Convert to string/bytes if needed, add doctests and unittests

git-svn-id: https://svn.osgeo.org/grass/grass/trunk@71394 15284696-431f-4ddb-bdfa-cd5b030d7da7
Pietro Zambelli 7 年 前
コミット
4b2ad31823
2 ファイル変更60 行追加3 行削除
  1. 24 0
      lib/python/script/testsuite/test_utils.py
  2. 36 3
      lib/python/script/utils.py

+ 24 - 0
lib/python/script/testsuite/test_utils.py

@@ -43,6 +43,18 @@ class TestEncode(TestCase):
         """If the input is bytes we should not touch it for encoding"""
         """If the input is bytes we should not touch it for encoding"""
         self.assertEqual(b'Příšerný kůň', utils.encode(b'Příšerný kůň'))
         self.assertEqual(b'Příšerný kůň', utils.encode(b'Příšerný kůň'))
 
 
+    def test_int(self):
+        """If the input is an integer return bytes"""
+        self.assertEqual(b'1234567890', utils.encode(1234567890))
+
+    def test_float(self):
+        """If the input is a float return bytes"""
+        self.assertEqual(b'12345.6789', utils.encode(12345.6789))
+
+    def test_none(self):
+        """If the input is a boolean return bytes"""
+        self.assertEqual(b'None', utils.encode(None))
+
 
 
 class TestDecode(TestCase):
 class TestDecode(TestCase):
     """Tests function `encode` that convert value to unicode."""
     """Tests function `encode` that convert value to unicode."""
@@ -53,6 +65,18 @@ class TestDecode(TestCase):
     def test_unicode(self):
     def test_unicode(self):
         self.assertEqual(u'text', utils.decode(u'text'))
         self.assertEqual(u'text', utils.decode(u'text'))
 
 
+    def test_int(self):
+        """If the input is an integer return bytes"""
+        self.assertEqual(u'1234567890', utils.decode(1234567890))
+
+    def test_float(self):
+        """If the input is a float return bytes"""
+        self.assertEqual(u'12345.6789', utils.decode(12345.6789))
+
+    def test_none(self):
+        """If the input is a boolean return bytes"""
+        self.assertEqual(u'None', utils.decode(None))
+
 
 
 class TestEncodeLcAllC(TestEncode, LcAllC):
 class TestEncodeLcAllC(TestEncode, LcAllC):
     pass
     pass

+ 36 - 3
lib/python/script/utils.py

@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 """
 """
 Useful functions to be used in Python scripts.
 Useful functions to be used in Python scripts.
 
 
@@ -24,6 +25,14 @@ import locale
 import shlex
 import shlex
 import re
 import re
 
 
+
+try:
+    from builtins import unicode
+except ImportError:
+    # python3
+    unicode = str
+
+
 def float_or_dms(s):
 def float_or_dms(s):
     """Convert DMS to float.
     """Convert DMS to float.
 
 
@@ -163,11 +172,23 @@ def decode(bytes_):
     No-op if parameter is not bytes (assumed unicode string).
     No-op if parameter is not bytes (assumed unicode string).
 
 
     :param bytes bytes_: the bytes to decode
     :param bytes bytes_: the bytes to decode
+
+    Example
+    -------
+
+    >>> decode(b'S\xc3\xbcdtirol')
+    u'Südtirol'
+    >>> decode(u'Südtirol')
+    u'Südtirol'
+    >>> decode(1234)
+    u'1234'
     """
     """
+    if isinstance(bytes_, unicode):
+        return bytes_
     if isinstance(bytes_, bytes):
     if isinstance(bytes_, bytes):
         enc = _get_encoding()
         enc = _get_encoding()
         return bytes_.decode(enc)
         return bytes_.decode(enc)
-    return bytes_
+    return unicode(bytes_)
 
 
 
 
 def encode(string):
 def encode(string):
@@ -177,11 +198,23 @@ def encode(string):
     This ensures garbage in, garbage out.
     This ensures garbage in, garbage out.
 
 
     :param str string: the string to encode
     :param str string: the string to encode
+
+    Example
+    -------
+
+    >>> encode(b'S\xc3\xbcdtirol')
+    b'S\xc3\xbcdtirol'
+    >>> decode(u'Südtirol')
+    b'S\xc3\xbcdtirol'
+    >>> decode(1234)
+    b'1234'
     """
     """
     if isinstance(string, bytes):
     if isinstance(string, bytes):
         return string
         return string
-    enc = _get_encoding()
-    return string.encode(enc)
+    if isinstance(string, unicode):
+        enc = _get_encoding()
+        return string.encode(enc)
+    return bytes(string)
 
 
 
 
 def parse_key_val(s, sep='=', dflt=None, val_type=None, vsep=None):
 def parse_key_val(s, sep='=', dflt=None, val_type=None, vsep=None):