memento.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """http://code.activestate.com/recipes/413838-memento-closure/"""
import copy
import functools
  5. def Memento(obj, deep=False):
  6. state = (copy.copy, copy.deepcopy)[bool(deep)](obj.__dict__)
  7. def Restore():
  8. obj.__dict__.clear()
  9. obj.__dict__.update(state)
  10. return Restore
  11. class Transaction:
  12. """A transaction guard. This is really just
  13. syntactic suggar arount a memento closure.
  14. """
  15. deep = False
  16. def __init__(self, *targets):
  17. self.targets = targets
  18. self.Commit()
  19. def Commit(self):
  20. self.states = [Memento(target, self.deep) for target in self.targets]
  21. def Rollback(self):
  22. for st in self.states:
  23. st()
  24. class transactional(object):
  25. """Adds transactional semantics to methods. Methods decorated with
  26. @transactional will rollback to entry state upon exceptions.
  27. """
  28. def __init__(self, method):
  29. self.method = method
  30. def __get__(self, obj, T):
  31. def transaction(*args, **kwargs):
  32. state = Memento(obj)
  33. try:
  34. return self.method(obj, *args, **kwargs)
  35. except:
  36. state()
  37. raise
  38. return transaction
  39. class NumObj(object):
  40. def __init__(self, value):
  41. self.value = value
  42. def __repr__(self):
  43. return '<%s: %r>' % (self.__class__.__name__, self.value)
  44. def Increment(self):
  45. self.value += 1
  46. @transactional
  47. def DoStuff(self):
  48. self.value = '1111' # <- invalid value
  49. self.Increment() # <- will fail and rollback
  50. if __name__ == '__main__':
  51. n = NumObj(-1)
  52. print(n)
  53. t = Transaction(n)
  54. try:
  55. for i in range(3):
  56. n.Increment()
  57. print(n)
  58. t.Commit()
  59. print('-- commited')
  60. for i in range(3):
  61. n.Increment()
  62. print(n)
  63. n.value += 'x' # will fail
  64. print(n)
  65. except:
  66. t.Rollback()
  67. print('-- rolled back')
  68. print(n)
  69. print('-- now doing stuff ...')
  70. try:
  71. n.DoStuff()
  72. except:
  73. print('-> doing stuff failed!')
  74. import traceback
  75. traceback.print_exc(0)
  76. pass
  77. print(n)