Skip to content

Python 类型注解

类型注解(Type Annotation)是 Python 3.5+ 引入的特性,用于为变量、函数参数和返回值添加类型信息。本章节将详细介绍 Python 中的类型注解概念、语法和使用方法。

为什么需要类型注解?

类型注解的主要作用包括:

  1. 提高代码可读性:类型注解使代码更加清晰,读者可以快速了解变量和函数的类型
  2. 提供 IDE 支持:IDE 可以根据类型注解提供更好的代码补全和错误提示
  3. 便于代码维护:类型注解可以帮助开发者理解代码的预期行为
  4. 类型检查:使用工具如 mypy 可以在运行前检测类型错误
  5. 文档生成:类型注解可以作为文档的一部分,说明函数的参数和返回值类型

基本类型注解

变量类型注解

在 Python 3.6+ 中,可以为变量添加类型注解:

python
# 变量类型注解

# 基本类型
name: str = "Alice"
age: int = 30
height: float = 1.65
is_student: bool = True

# 容器类型
from typing import List, Tuple, Dict, Set

# 列表
numbers: List[int] = [1, 2, 3, 4, 5]
names: List[str] = ["Alice", "Bob", "Charlie"]

# 元组
default_tuple: Tuple[int, str, bool] = (1, "Alice", True)

# 字典
person: Dict[str, any] = {"name": "Alice", "age": 30, "is_student": True}

# 集合
tags: Set[str] = {"python", "programming", "coding"}

# 可选类型
from typing import Optional

# 可选类型(可以是 int 或 None)
optional_number: Optional[int] = None
optional_number = 42

# 联合类型
from typing import Union

# 联合类型(可以是 int 或 str)
union_type: Union[int, str] = 42
union_type = "Hello"

# 类型别名
from typing import TypeAlias

# 类型别名
UserId: TypeAlias = int
UserIds: TypeAlias = List[UserId]

user_id: UserId = 123
user_ids: UserIds = [1, 2, 3, 4, 5]

函数类型注解

python
# 函数类型注解

# 基本函数类型注解
def add(a: int, b: int) -> int:
    """计算两个数的和"""
    return a + b

# 带默认值的参数
def greet(name: str, greeting: str = "Hello") -> str:
    """问候函数"""
    return f"{greeting}, {name}!"

# 可选参数
def get_user(user_id: Optional[int] = None) -> dict:
    """获取用户信息"""
    if user_id:
        return {"id": user_id, "name": "Alice"}
    return {"id": None, "name": "Guest"}

# 可变参数
def sum_numbers(*numbers: int) -> int:
    """计算多个数的和"""
    return sum(numbers)

# 关键字可变参数
def create_person(**kwargs: Union[str, int, bool]) -> dict:
    """创建人物信息"""
    return kwargs

# 混合参数
def process_data(data: List[int], threshold: float = 0.5, *extra: str, **options: bool) -> List[int]:
    """处理数据"""
    result = [x for x in data if x > threshold]
    return result

# 无返回值
def print_message(message: str) -> None:
    """打印消息"""
    print(message)

# 调用函数
print(f"add(1, 2) = {add(1, 2)}")
print(f"greet('Alice') = {greet('Alice')}")
print(f"greet('Bob', 'Hi') = {greet('Bob', 'Hi')}")
print(f"get_user() = {get_user()}")
print(f"get_user(123) = {get_user(123)}")
print(f"sum_numbers(1, 2, 3, 4, 5) = {sum_numbers(1, 2, 3, 4, 5)}")
print(f"create_person(name='Alice', age=30, is_student=True) = {create_person(name='Alice', age=30, is_student=True)}")
print(f"process_data([1, 2, 3, 4, 5], 2) = {process_data([1, 2, 3, 4, 5], 2)}")
print_message("Hello, Type Annotation!")

类和实例类型注解

python
# 类和实例类型注解

from typing import List, Optional

class Person:
    """Person 类"""
    
    # 类属性类型注解
    species: str = "人类"
    
    # 实例属性类型注解
    def __init__(self, name: str, age: int, height: Optional[float] = None):
        """初始化方法"""
        self.name: str = name
        self.age: int = age
        self.height: Optional[float] = height
    
    # 方法类型注解
    def greet(self) -> str:
        """问候方法"""
        return f"你好,我是 {self.name},今年 {self.age} 岁。"
    
    def celebrate_birthday(self) -> None:
        """庆祝生日方法"""
        self.age += 1
        print(f"生日快乐!现在 {self.name}{self.age} 岁了。")

# 类作为类型注解
class Student(Person):
    """Student 类"""
    
    def __init__(self, name: str, age: int, student_id: str, height: Optional[float] = None):
        """初始化方法"""
        super().__init__(name, age, height)
        self.student_id: str = student_id
    
    def study(self, subject: str) -> str:
        """学习方法"""
        return f"{self.name} 正在学习 {subject}。"

# 使用类作为类型注解
def get_student_name(student: Student) -> str:
    """获取学生姓名"""
    return student.name

# 创建实例
person: Person = Person("Alice", 30, 1.65)
student: Student = Student("Bob", 20, "S123", 1.75)

# 调用方法
print(person.greet())
person.celebrate_birthday()

print(student.greet())
print(student.study("Python"))

# 调用函数
print(f"学生姓名:{get_student_name(student)}")

高级类型注解

泛型类型

python
# 泛型类型

from typing import TypeVar, List, Dict, Generic

# 定义类型变量
T = TypeVar('T')
U = TypeVar('U')

# 泛型函数
def first_element(items: List[T]) -> T:
    """获取列表的第一个元素"""
    if items:
        return items[0]
    raise IndexError("列表为空")

# 泛型类
class Stack(Generic[T]):
    """泛型栈类"""
    
    def __init__(self):
        """初始化方法"""
        self.items: List[T] = []
    
    def push(self, item: T) -> None:
        """入栈"""
        self.items.append(item)
    
    def pop(self) -> T:
        """出栈"""
        if self.items:
            return self.items.pop()
        raise IndexError("栈为空")
    
    def is_empty(self) -> bool:
        """检查栈是否为空"""
        return len(self.items) == 0

# 使用泛型函数
numbers: List[int] = [1, 2, 3, 4, 5]
first_num: int = first_element(numbers)
print(f"列表 {numbers} 的第一个元素:{first_num}")

names: List[str] = ["Alice", "Bob", "Charlie"]
first_name: str = first_element(names)
print(f"列表 {names} 的第一个元素:{first_name}")

# 使用泛型类
# 整数栈
int_stack: Stack[int] = Stack()
int_stack.push(1)
int_stack.push(2)
int_stack.push(3)
print(f"整数栈弹出:{int_stack.pop()}")

# 字符串栈
str_stack: Stack[str] = Stack()
str_stack.push("Hello")
str_stack.push("World")
print(f"字符串栈弹出:{str_stack.pop()}")

协议类型

python
# 协议类型

from typing import Protocol

# 定义协议
class Drawable(Protocol):
    """可绘制协议"""
    def draw(self) -> None:
        """绘制方法"""
        ...

# 实现协议的类
class Circle:
    """圆形类"""
    def draw(self) -> None:
        """绘制圆形"""
        print("绘制圆形")

class Rectangle:
    """矩形类"""
    def draw(self) -> None:
        """绘制矩形"""
        print("绘制矩形")

class Triangle:
    """三角形类"""
    def draw(self) -> None:
        """绘制三角形"""
        print("绘制三角形")

# 使用协议作为类型注解
def draw_shape(shape: Drawable) -> None:
    """绘制形状"""
    shape.draw()

# 创建实例
circle: Circle = Circle()
rectangle: Rectangle = Rectangle()
triangle: Triangle = Triangle()

# 调用函数
draw_shape(circle)
draw_shape(rectangle)
draw_shape(triangle)

回调函数类型

python
# 回调函数类型

from typing import Callable

# 回调函数类型注解
def process_data(data: List[int], callback: Callable[[int], int]) -> List[int]:
    """处理数据并应用回调函数"""
    return [callback(item) for item in data]

# 回调函数 1:平方
def square(x: int) -> int:
    """计算平方"""
    return x ** 2

# 回调函数 2:加倍
def double(x: int) -> int:
    """计算加倍"""
    return x * 2

# 使用 lambda 作为回调函数
data: List[int] = [1, 2, 3, 4, 5]

# 使用 square 函数
result1: List[int] = process_data(data, square)
print(f"平方结果:{result1}")

# 使用 double 函数
result2: List[int] = process_data(data, double)
print(f"加倍结果:{result2}")

# 使用 lambda 表达式
result3: List[int] = process_data(data, lambda x: x + 1)
print(f"加 1 结果:{result3}")

类型守卫

python
# 类型守卫

from typing import Union, TypeGuard

# 类型守卫函数
def is_string(value: Union[str, int, float]) -> TypeGuard[str]:
    """检查值是否为字符串"""
    return isinstance(value, str)

def is_integer(value: Union[str, int, float]) -> TypeGuard[int]:
    """检查值是否为整数"""
    return isinstance(value, int)

# 使用类型守卫
def process_value(value: Union[str, int, float]) -> None:
    """处理值"""
    if is_string(value):
        print(f"字符串处理:{value.upper()}")
    elif is_integer(value):
        print(f"整数处理:{value * 2}")
    else:
        print(f"浮点数处理:{value:.2f}")

# 测试
process_value("hello")  # 字符串
process_value(42)        # 整数
process_value(3.14)      # 浮点数

类型检查工具

mypy

mypy 是一个流行的 Python 类型检查工具,可以在运行前检测类型错误。

安装 mypy

bash
pip install mypy

使用 mypy

创建一个 example.py 文件:

python
# example.py

from typing import List, Optional

def add(a: int, b: int) -> int:
    return a + b

def greet(name: str, age: Optional[int] = None) -> str:
    if age:
        return f"Hello, {name}! You are {age} years old."
    return f"Hello, {name}!"

# 类型错误:传递字符串给需要整数的参数
result = add("1", "2")
print(result)

# 类型错误:传递整数给需要字符串的参数
message = greet(42)
print(message)

# 正确用法
result = add(1, 2)
print(result)

message = greet("Alice", 30)
print(message)

运行 mypy 检查:

bash
mypy example.py

输出:

example.py:15: error: Argument 1 to "add" has incompatible type "str"; expected "int"
example.py:15: error: Argument 2 to "add" has incompatible type "str"; expected "int"
example.py:19: error: Argument 1 to "greet" has incompatible type "int"; expected "str"
Found 3 errors in 1 file (checked 1 source file)

mypy 会检测出类型错误,帮助开发者在运行前发现问题。

PyCharm 类型检查

PyCharm 内置了类型检查功能,可以实时检测类型错误:

  1. 打开 PyCharm 并创建一个 Python 文件
  2. 编写带有类型注解的代码
  3. PyCharm 会在有类型错误的地方显示红色波浪线
  4. 鼠标悬停在错误上可以查看详细的错误信息

VS Code 类型检查

在 VS Code 中,可以使用 Python 扩展和 mypy 进行类型检查:

  1. 安装 Python 扩展
  2. 安装 mypy:pip install mypy
  3. 在 VS Code 中打开设置(Ctrl+,)
  4. 搜索 "mypy" 并启用 "Python > Linting: Mypy Enabled"
  5. 编写带有类型注解的代码,VS Code 会显示类型错误

类型注解的最佳实践

1. 为公共 API 添加类型注解

为模块的公共函数、类和方法添加类型注解,提高代码的可读性和可维护性。

2. 为复杂类型使用类型别名

对于复杂的类型,可以使用类型别名提高代码的可读性:

python
from typing import List, Dict, TypeAlias

# 类型别名
UserId: TypeAlias = int
User: TypeAlias = Dict[str, Union[str, int, bool]]
UserList: TypeAlias = List[User]

# 使用类型别名
def get_user(user_id: UserId) -> User:
    """获取用户信息"""
    return {"id": user_id, "name": "Alice", "age": 30}

def get_users() -> UserList:
    """获取用户列表"""
    return [
        {"id": 1, "name": "Alice", "age": 30},
        {"id": 2, "name": "Bob", "age": 25}
    ]

3. 为函数参数和返回值添加类型注解

为函数的参数和返回值添加类型注解,使函数的预期行为更加清晰:

python
from typing import List, Optional

# 好的做法:添加类型注解
def calculate_average(numbers: List[float]) -> Optional[float]:
    """计算平均值"""
    if not numbers:
        return None
    return sum(numbers) / len(numbers)

# 不好的做法:没有类型注解
def calculate_average(numbers):
    """计算平均值"""
    if not numbers:
        return None
    return sum(numbers) / len(numbers)

4. 使用 Optional 表示可能为 None 的值

对于可能为 None 的值,使用 Optional 类型注解:

python
from typing import Optional

# 好的做法:使用 Optional
def find_user(user_id: int) -> Optional[dict]:
    """查找用户"""
    users = {1: {"id": 1, "name": "Alice"}, 2: {"id": 2, "name": "Bob"}}
    return users.get(user_id)

# 不好的做法:没有明确表示可能为 None
def find_user(user_id: int) -> dict:
    """查找用户"""
    users = {1: {"id": 1, "name": "Alice"}, 2: {"id": 2, "name": "Bob"}}
    return users.get(user_id)  # 可能返回 None

5. 为容器类型指定元素类型

为列表、字典等容器类型指定元素的类型:

python
from typing import List, Dict, Set

# 好的做法:指定容器元素类型
numbers: List[int] = [1, 2, 3, 4, 5]
person: Dict[str, Union[str, int]] = {"name": "Alice", "age": 30}
tags: Set[str] = {"python", "programming"}

# 不好的做法:没有指定容器元素类型
numbers = [1, 2, 3, 4, 5]  # 不知道元素类型
person = {"name": "Alice", "age": 30}  # 不知道键值类型
tags = {"python", "programming"}  # 不知道元素类型

6. 不要过度使用类型注解

对于简单的脚本或内部函数,可以适当减少类型注解,避免代码过于冗长:

python
# 简单脚本,不需要过多类型注解
if __name__ == "__main__":
    # 简单变量,不需要类型注解
    x = 10
    y = 20
    print(f"x + y = {x + y}")

实际应用示例

示例 1:数据处理函数

python
# 数据处理函数

from typing import List, Dict, Union, Optional

# 类型别名
DataPoint: TypeAlias = Dict[str, Union[int, float, str]]
Dataset: TypeAlias = List[DataPoint]

# 数据处理函数
def process_dataset(dataset: Dataset, threshold: float = 0.5) -> Dataset:
    """处理数据集"""
    processed: Dataset = []
    for data_point in dataset:
        # 过滤掉值小于阈值的数据
        if isinstance(data_point.get('value'), (int, float)) and data_point['value'] >= threshold:
            processed.append(data_point)
    return processed

# 数据统计函数
def calculate_statistics(dataset: Dataset) -> Dict[str, Optional[float]]:
    """计算数据集统计信息"""
    values: List[float] = []
    for data_point in dataset:
        if isinstance(data_point.get('value'), (int, float)):
            values.append(float(data_point['value']))
    
    if not values:
        return {
            "count": 0,
            "mean": None,
            "min": None,
            "max": None
        }
    
    return {
        "count": len(values),
        "mean": sum(values) / len(values),
        "min": min(values),
        "max": max(values)
    }

# 测试数据
test_dataset: Dataset = [
    {"id": 1, "value": 0.8, "label": "positive"},
    {"id": 2, "value": 0.3, "label": "negative"},
    {"id": 3, "value": 0.9, "label": "positive"},
    {"id": 4, "value": 0.2, "label": "negative"},
    {"id": 5, "value": 0.6, "label": "positive"}
]

# 处理数据
processed_data: Dataset = process_dataset(test_dataset, threshold=0.5)
print("处理后的数据:")
for item in processed_data:
    print(item)

# 计算统计信息
stats: Dict[str, Optional[float]] = calculate_statistics(test_dataset)
print("\n统计信息:")
for key, value in stats.items():
    print(f"{key}: {value}")

示例 2:API 客户端

python
# API 客户端

from typing import Dict, List, Optional, Any
import requests

class APIClient:
    """API 客户端"""
    
    def __init__(self, base_url: str, api_key: str):
        """初始化 API 客户端"""
        self.base_url: str = base_url
        self.api_key: str = api_key
        self.session: requests.Session = requests.Session()
        self.session.headers.update({"Authorization": f"Bearer {api_key}"})
    
    def get_users(self, page: int = 1, page_size: int = 10) -> List[Dict[str, Any]]:
        """获取用户列表"""
        url: str = f"{self.base_url}/users"
        params: Dict[str, int] = {"page": page, "page_size": page_size}
        response: requests.Response = self.session.get(url, params=params)
        response.raise_for_status()
        return response.json().get("data", [])
    
    def get_user(self, user_id: int) -> Optional[Dict[str, Any]]:
        """获取单个用户信息"""
        url: str = f"{self.base_url}/users/{user_id}"
        try:
            response: requests.Response = self.session.get(url)
            response.raise_for_status()
            return response.json().get("data")
        except requests.HTTPError:
            return None
    
    def create_user(self, user_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """创建用户"""
        url: str = f"{self.base_url}/users"
        try:
            response: requests.Response = self.session.post(url, json=user_data)
            response.raise_for_status()
            return response.json().get("data")
        except requests.HTTPError:
            return None
    
    def update_user(self, user_id: int, user_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """更新用户信息"""
        url: str = f"{self.base_url}/users/{user_id}"
        try:
            response: requests.Response = self.session.put(url, json=user_data)
            response.raise_for_status()
            return response.json().get("data")
        except requests.HTTPError:
            return None
    
    def delete_user(self, user_id: int) -> bool:
        """删除用户"""
        url: str = f"{self.base_url}/users/{user_id}"
        try:
            response: requests.Response = self.session.delete(url)
            response.raise_for_status()
            return True
        except requests.HTTPError:
            return False

# 使用示例
if __name__ == "__main__":
    # 创建 API 客户端
    client: APIClient = APIClient("https://api.example.com", "your-api-key")
    
    # 获取用户列表
    users: List[Dict[str, Any]] = client.get_users()
    print(f"获取到 {len(users)} 个用户")
    
    # 创建用户
    new_user: Dict[str, Any] = {
        "name": "Alice",
        "email": "alice@example.com",
        "age": 30
    }
    created_user: Optional[Dict[str, Any]] = client.create_user(new_user)
    if created_user:
        print(f"创建用户成功:{created_user['name']}")
    
    # 更新用户
    if created_user:
        update_data: Dict[str, Any] = {"age": 31}
        updated_user: Optional[Dict[str, Any]] = client.update_user(created_user["id"], update_data)
        if updated_user:
            print(f"更新用户成功:{updated_user['name']},年龄:{updated_user['age']}")
    
    # 删除用户
    if created_user:
        deleted: bool = client.delete_user(created_user["id"])
        if deleted:
            print(f"删除用户成功:{created_user['name']}")

示例 3:配置管理

python
# 配置管理

from typing import Dict, List, Optional, Any, TypeAlias
import json
import os

# 类型别名
Config: TypeAlias = Dict[str, Any]

class ConfigManager:
    """配置管理器"""
    
    def __init__(self, config_file: str = "config.json"):
        """初始化配置管理器"""
        self.config_file: str = config_file
        self.config: Config = self._load_config()
    
    def _load_config(self) -> Config:
        """加载配置文件"""
        if os.path.exists(self.config_file):
            try:
                with open(self.config_file, 'r', encoding='utf-8') as f:
                    return json.load(f)
            except (json.JSONDecodeError, IOError):
                return {}
        return {}
    
    def _save_config(self) -> None:
        """保存配置文件"""
        try:
            with open(self.config_file, 'w', encoding='utf-8') as f:
                json.dump(self.config, f, indent=2, ensure_ascii=False)
        except IOError:
            pass
    
    def get(self, key: str, default: Any = None) -> Any:
        """获取配置值"""
        # 支持点号分隔的路径
        keys: List[str] = key.split('.')
        value: Any = self.config
        
        for k in keys:
            if isinstance(value, dict) and k in value:
                value = value[k]
            else:
                return default
        
        return value
    
    def set(self, key: str, value: Any) -> None:
        """设置配置值"""
        # 支持点号分隔的路径
        keys: List[str] = key.split('.')
        config: Config = self.config
        
        # 遍历键路径,创建不存在的嵌套字典
        for k in keys[:-1]:
            if k not in config or not isinstance(config[k], dict):
                config[k] = {}
            config = config[k]
        
        # 设置最终值
        config[keys[-1]] = value
        self._save_config()
    
    def delete(self, key: str) -> bool:
        """删除配置值"""
        # 支持点号分隔的路径
        keys: List[str] = key.split('.')
        config: Config = self.config
        
        # 遍历键路径
        for k in keys[:-1]:
            if k not in config or not isinstance(config[k], dict):
                return False
            config = config[k]
        
        # 删除键
        if keys[-1] in config:
            del config[keys[-1]]
            self._save_config()
            return True
        return False
    
    def get_all(self) -> Config:
        """获取所有配置"""
        return self.config
    
    def set_all(self, config: Config) -> None:
        """设置所有配置"""
        self.config = config
        self._save_config()

# 使用示例
if __name__ == "__main__":
    # 创建配置管理器
    config_manager: ConfigManager = ConfigManager()
    
    # 设置配置
    config_manager.set("app.name", "My Application")
    config_manager.set("app.version", "1.0.0")
    config_manager.set("database.host", "localhost")
    config_manager.set("database.port", 5432)
    config_manager.set("database.credentials.username", "admin")
    config_manager.set("database.credentials.password", "password")
    
    # 获取配置
    app_name: str = config_manager.get("app.name", "Default App")
    app_version: str = config_manager.get("app.version", "0.0.1")
    db_host: str = config_manager.get("database.host", "127.0.0.1")
    db_port: int = config_manager.get("database.port", 3306)
    db_username: str = config_manager.get("database.credentials.username", "root")
    db_password: str = config_manager.get("database.credentials.password", "")
    
    print(f"应用名称:{app_name}")
    print(f"应用版本:{app_version}")
    print(f"数据库主机:{db_host}")
    print(f"数据库端口:{db_port}")
    print(f"数据库用户名:{db_username}")
    print(f"数据库密码:{db_password}")
    
    # 获取所有配置
    all_config: Config = config_manager.get_all()
    print("\n所有配置:")
    print(json.dumps(all_config, indent=2, ensure_ascii=False))
    
    # 删除配置
    deleted: bool = config_manager.delete("database.credentials.password")
    print(f"\n删除数据库密码:{deleted}")
    
    # 再次获取配置
    db_password = config_manager.get("database.credentials.password", "未设置")
    print(f"数据库密码:{db_password}")

总结

Python 类型注解是一种强大的工具,可以提高代码的可读性、可维护性和可靠性。本章节介绍了:

  1. 基本类型注解:变量、函数参数和返回值的类型注解
  2. 高级类型注解:泛型、协议、回调函数类型等
  3. 类型检查工具:mypy、PyCharm 和 VS Code 的类型检查功能
  4. 最佳实践:为公共 API 添加类型注解、使用类型别名等
  5. 实际应用示例:数据处理、API 客户端、配置管理等

通过合理使用类型注解,可以使 Python 代码更加健壮、易于理解和维护。