Functools -- total_ordering
functools.total_ordering 是一个 类装饰器。
它可以让你只实现 1 个比较方法 + __eq__,Python 自动为你生成剩下的比较方法:
| 你实现 | 自动补全 |
|---|---|
__eq__ + __lt__ |
自动生成 __le__, __gt__, __ge__ |
__eq__ + __gt__ |
自动生成 __ge__, __lt__, __le__ |
__eq__ + __le__ |
自动生成 __lt__, __gt__, __ge__ |
- 原本你要写 4 个方法,
- 现在只要写 2 个。
🟦 2. 最小核心示例
from functools import total_ordering
@total_ordering
class User:
def __init__(self, score):
self.score = score
def __eq__(self, other):
return self.score == other.score
def __lt__(self, other):
return self.score < other.score
u1 = User(10)
u2 = User(20)
print(u1 < u2) # True
print(u1 <= u2) # True(自动生成)
print(u1 > u2) # False(自动生成)
print(u1 >= u2) # False(自动生成)
from functools import total_ordering
@total_ordering
class ClassName:
def __eq__(self, other): ...
def __lt__(self, other): ...
只需两点:
-
必须实现
__eq__ -
其他五个比较方法至少实现一个:
__lt____le____gt____ge__
Python 会自动补全剩下的。
@total_ordering
class Task:
def __init__(self, priority):
self.priority = priority
def __eq__(self, other):
return self.priority == other.priority
def __lt__(self, other):
return self.priority < other.priority
tasks = [Task(5), Task(1), Task(3)]
tasks_sorted = sorted(tasks) # 自动用 < 排序
@total_ordering
class News:
def __init__(self, ts):
self.ts = ts
def __eq__(self, other):
return self.ts == other.ts
def __lt__(self, other):
return self.ts < other.ts
你就可以:
sorted_news = sorted(news_list, reverse=True)
例如你在 FastAPI + SQLAlchemy 项目中:
@total_ordering
class Stock:
def __init__(self, code, pe):
self.code = code
self.pe = pe
def __eq__(self, other):
return self.pe == other.pe
def __lt__(self, other):
return self.pe < other.pe
直接对 Query 出来的对象排序:
sorted(stock_list)
你可以做一个“权重对象”:
@total_ordering
class RankedItem:
def __init__(self, item, weight):
self.item = item
self.weight = weight
def __eq__(self, other):
return self.weight == other.weight
def __lt__(self, other):
return self.weight < other.weight
否则会报错:
TypeError: '<' not supported between instances
你可以加保护:
if not isinstance(other, self.__class__):
return NotImplemented
如果你希望对象可作为 dict key:
@total_ordering
class User:
__hash__ = object.__hash__
例如:公司新闻、股票、权重、优先级、时间戳 —— 都适合。
from functools import total_ordering
@total_ordering
class BaseComparable:
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return self.key() == other.key()
def __lt__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return self.key() < other.key()
def key(self):
raise NotImplementedError("子类必须实现 key() 返回比较依据")
子类只需要覆盖一个方法:
class News(BaseComparable):
def __init__(self, ts):
self.ts = ts
def key(self):
return self.ts
total_ordering:根据已实现的比较方法自动生成剩下的比较方法
必须实现:__eq__
还需至少实现一个:
__lt____le____gt____ge__
然后 Python 自动补全其余:<, <=, >, >= 全部可用
典型用法:
@total_ordering
class C:
def __eq__(self, other): ...
def __lt__(self, other): ...