浏览代码

Fixed #13 wrong distance of LCS

luozhouyang 5 年之前
父节点
当前提交
99c5618b35
共有 2 个文件被更改,包括 10 次插入10 次删除
  1. 1 2
      strsimpy/longest_common_subsequence.py
  2. 9 8
      strsimpy/longest_common_subsequence_test.py

+ 1 - 2
strsimpy/longest_common_subsequence.py

@@ -39,8 +39,7 @@ class LongestCommonSubsequence(StringDistance):
             raise TypeError("Argument s1 is NoneType.")
         s0_len, s1_len = len(s0), len(s1)
         x, y = s0[:], s1[:]
-        n, m = s0_len + 1, s1_len + 1
-        matrix = [[0] * m for _ in range(n)]
+        matrix = [[0] * (s1_len+1) for _ in range(s0_len + 1)]
         for i in range(1, s0_len + 1):
             for j in range(1, s1_len + 1):
                 if x[i - 1] == y[j - 1]:

+ 9 - 8
strsimpy/longest_common_subsequence_test.py

@@ -23,7 +23,7 @@ import unittest
 from .longest_common_subsequence import LongestCommonSubsequence
 
 
-class TestLongestCommonSubsequence(unittest.TestCase):
+class LongestCommonSubsequenceTest(unittest.TestCase):
 
     def test_longest_common_subsequence(self):
         a = LongestCommonSubsequence()
@@ -31,13 +31,14 @@ class TestLongestCommonSubsequence(unittest.TestCase):
         s1 = ""
         s2 = "上海"
         s3 = "上海市"
-        distance_format = "distance: {:.4}\t between {} and {}"
-        print(distance_format.format(str(a.distance(s0, s1)), s0, s1))
-        print(distance_format.format(str(a.distance(s0, s2)), s0, s2))
-        print(distance_format.format(str(a.distance(s0, s3)), s0, s3))
-        print(distance_format.format(str(a.distance(s1, s2)), s1, s2))
-        print(distance_format.format(str(a.distance(s1, s3)), s1, s3))
-        print(distance_format.format(str(a.distance(s2, s3)), s2, s3))
+
+        self.assertEqual(0, a.distance(s0, s1))
+        self.assertEqual(2, a.distance(s0, s2))
+        self.assertEqual(3, a.distance(s0, s3))
+        self.assertEqual(1, a.distance(s2, s3))
+        self.assertEqual(2, a.length(s2, s3))
+        self.assertEqual(4, a.distance('AGCAT', 'GAC'))
+        self.assertEqual(2, a.length('AGCAT', 'GAC'))
 
 
 if __name__ == "__main__":