Skip to content

Python 多线程编程

多线程是 Python 中实现并发编程的一种重要方式,它允许程序同时执行多个任务。本章节将详细介绍 Python 中的多线程编程概念、使用方法和最佳实践。

什么是多线程?

线程是程序执行的最小单位,一个进程可以包含多个线程。多线程编程是指在一个程序中同时运行多个线程,每个线程执行不同的任务。

多线程的优点

  1. 提高程序响应速度:当一个线程执行 I/O 操作时,其他线程可以继续执行
  2. 充分利用 CPU 资源:在多核 CPU 上,多线程可以并行执行
  3. 简化程序设计:将不同的任务分配给不同的线程,使程序结构更清晰

多线程的缺点

  1. 线程安全问题:多个线程访问共享资源时可能导致数据竞争
  2. 死锁:多个线程相互等待对方释放资源
  3. 上下文切换开销:线程之间的切换会产生一定的开销
  4. GIL 限制:在 CPython 中,全局解释器锁(GIL)限制了多线程的并行性能

Python 中的多线程模块

Python 提供了两个主要的多线程模块:

  1. threading:更高级的多线程模块,提供了更多功能
  2. _thread:低级别的多线程模块,一般不推荐使用

本章节主要介绍 threading 模块的使用。

基本线程创建

使用 threading.Thread 类

python
import threading
import time

# 定义线程函数
def thread_function(name):
    """线程函数"""
    print(f"线程 {name} 开始执行")
    time.sleep(2)  # 模拟耗时操作
    print(f"线程 {name} 执行完成")

# 创建线程
thread1 = threading.Thread(target=thread_function, args=("1",))
thread2 = threading.Thread(target=thread_function, args=("2",))

# 启动线程
print("主线程开始")
thread1.start()
thread2.start()

# 等待线程完成
thread1.join()
thread2.join()

print("主线程结束")

继承 threading.Thread 类

python
import threading
import time

# 继承 Thread 类
class MyThread(threading.Thread):
    """自定义线程类"""
    
    def __init__(self, name):
        """初始化线程"""
        super().__init__()
        self.name = name
    
    def run(self):
        """线程运行函数"""
        print(f"线程 {self.name} 开始执行")
        time.sleep(2)  # 模拟耗时操作
        print(f"线程 {self.name} 执行完成")

# 创建线程
thread1 = MyThread("1")
thread2 = MyThread("2")

# 启动线程
print("主线程开始")
thread1.start()
thread2.start()

# 等待线程完成
thread1.join()
thread2.join()

print("主线程结束")

线程同步

锁(Lock)

当多个线程访问共享资源时,需要使用锁来确保数据的一致性。

python
import threading
import time

# 共享变量
counter = 0

# 创建锁
lock = threading.Lock()

# 线程函数
def increment_counter():
    """增加计数器"""
    global counter
    for _ in range(100000):
        # 获取锁
        lock.acquire()
        try:
            # 临界区
            counter += 1
        finally:
            # 释放锁
            lock.release()

# 创建线程
threads = []
for i in range(5):
    thread = threading.Thread(target=increment_counter)
    threads.append(thread)

# 启动线程
print("主线程开始")
print(f"初始计数器值: {counter}")

for thread in threads:
    thread.start()

# 等待线程完成
for thread in threads:
    thread.join()

print(f"最终计数器值: {counter}")
print("主线程结束")

使用 with 语句简化锁操作

python
import threading
import time

# 共享变量
counter = 0

# 创建锁
lock = threading.Lock()

# 线程函数
def increment_counter():
    """增加计数器"""
    global counter
    for _ in range(100000):
        # 使用 with 语句自动获取和释放锁
        with lock:
            # 临界区
            counter += 1

# 创建线程
threads = []
for i in range(5):
    thread = threading.Thread(target=increment_counter)
    threads.append(thread)

# 启动线程
print("主线程开始")
print(f"初始计数器值: {counter}")

for thread in threads:
    thread.start()

# 等待线程完成
for thread in threads:
    thread.join()

print(f"最终计数器值: {counter}")
print("主线程结束")

可重入锁(RLock)

可重入锁允许同一个线程多次获取锁,而不会导致死锁。

python
import threading
import time

# 创建可重入锁
lock = threading.RLock()

# 线程函数
def nested_function(level):
    """嵌套函数"""
    if level > 3:
        return
    
    with lock:
        print(f"线程 {threading.current_thread().name} 在级别 {level} 获取锁")
        time.sleep(0.1)
        nested_function(level + 1)
        print(f"线程 {threading.current_thread().name} 在级别 {level} 释放锁")

# 创建线程
thread1 = threading.Thread(target=nested_function, args=(1,), name="Thread-1")
thread2 = threading.Thread(target=nested_function, args=(1,), name="Thread-2")

# 启动线程
print("主线程开始")
thread1.start()
thread2.start()

# 等待线程完成
thread1.join()
thread2.join()

print("主线程结束")

信号量(Semaphore)

信号量用于控制对有限资源的访问。

python
import threading
import time
import random

# 创建信号量,最多允许 3 个线程同时访问
semaphore = threading.Semaphore(3)

# 线程函数
def worker(name):
    """工作线程"""
    print(f"工人 {name} 等待使用工具")
    
    # 获取信号量
    with semaphore:
        print(f"工人 {name} 开始使用工具")
        # 模拟使用工具的时间
        time.sleep(random.randint(1, 3))
        print(f"工人 {name} 完成使用工具")

# 创建线程
threads = []
for i in range(10):
    thread = threading.Thread(target=worker, args=(f"{i+1}",))
    threads.append(thread)

# 启动线程
print("主线程开始")

for thread in threads:
    thread.start()

# 等待线程完成
for thread in threads:
    thread.join()

print("主线程结束")

事件(Event)

事件用于线程间的通知机制。

python
import threading
import time

# 创建事件
event = threading.Event()

# 消费者线程函数
def consumer():
    """消费者线程"""
    print("消费者等待产品")
    # 等待事件触发
    event.wait()
    print("消费者收到产品,开始消费")
    time.sleep(2)
    print("消费者消费完成")

# 生产者线程函数
def producer():
    """生产者线程"""
    print("生产者开始生产产品")
    time.sleep(3)  # 模拟生产时间
    print("生产者生产完成")
    # 触发事件
    event.set()
    print("生产者通知消费者")

# 创建线程
consumer_thread = threading.Thread(target=consumer)
producer_thread = threading.Thread(target=producer)

# 启动线程
print("主线程开始")
consumer_thread.start()
producer_thread.start()

# 等待线程完成
consumer_thread.join()
producer_thread.join()

print("主线程结束")

条件变量(Condition)

条件变量用于线程间的复杂同步。

python
import threading
import time

# 创建条件变量
condition = threading.Condition()

# 共享队列
queue = []
MAX_ITEMS = 5

# 生产者线程函数
def producer():
    """生产者线程"""
    for i in range(10):
        with condition:
            # 等待队列不满
            while len(queue) >= MAX_ITEMS:
                print(f"队列已满,生产者等待")
                condition.wait()
            
            # 生产产品
            item = f"产品 {i}"
            queue.append(item)
            print(f"生产者生产: {item}, 当前队列: {queue}")
            
            # 通知消费者
            condition.notify_all()
            
        # 模拟生产时间
        time.sleep(0.5)

# 消费者线程函数
def consumer(name):
    """消费者线程"""
    for _ in range(5):
        with condition:
            # 等待队列不为空
            while not queue:
                print(f"队列为空,消费者 {name} 等待")
                condition.wait()
            
            # 消费产品
            item = queue.pop(0)
            print(f"消费者 {name} 消费: {item}, 当前队列: {queue}")
            
            # 通知生产者
            condition.notify_all()
            
        # 模拟消费时间
        time.sleep(1)

# 创建线程
producer_thread = threading.Thread(target=producer)
consumer_thread1 = threading.Thread(target=consumer, args=("1",))
consumer_thread2 = threading.Thread(target=consumer, args=("2",))

# 启动线程
print("主线程开始")
producer_thread.start()
consumer_thread1.start()
consumer_thread2.start()

# 等待线程完成
producer_thread.join()
consumer_thread1.join()
consumer_thread2.join()

print("主线程结束")

定时器(Timer)

定时器用于在指定时间后执行函数。

python
import threading
import time

# 定时器回调函数
def timer_callback():
    """定时器回调函数"""
    print(f"定时器触发,当前时间: {time.strftime('%Y-%m-%d %H:%M:%S')}")

# 创建定时器,2秒后执行
print(f"主线程开始,当前时间: {time.strftime('%Y-%m-%d %H:%M:%S')}")
timer = threading.Timer(2, timer_callback)

# 启动定时器
timer.start()

# 主线程继续执行
print("主线程继续执行")
time.sleep(1)
print("主线程等待定时器完成")

# 等待定时器完成
timer.join()

print("主线程结束")

# 取消定时器示例
def cancel_callback():
    print("这个回调不会执行,因为定时器被取消了")

print("\n测试取消定时器:")
timer2 = threading.Timer(2, cancel_callback)
timer2.start()
print("取消定时器")
timer2.cancel()
time.sleep(3)
print("主线程结束")

线程池

线程池用于管理和复用线程,减少线程创建和销毁的开销。

使用 concurrent.futures.ThreadPoolExecutor

python
import concurrent.futures
import time
import random

# 工作函数
def worker(name, seconds):
    """工作函数"""
    print(f"任务 {name} 开始执行,需要 {seconds} 秒")
    time.sleep(seconds)
    print(f"任务 {name} 执行完成")
    return f"任务 {name} 的结果"

# 创建线程池
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
    # 提交任务
    future_to_task = {}
    for i in range(5):
        seconds = random.randint(1, 3)
        future = executor.submit(worker, f"{i+1}", seconds)
        future_to_task[future] = f"任务 {i+1}"
    
    print("所有任务已提交")
    
    # 等待任务完成并获取结果
    for future in concurrent.futures.as_completed(future_to_task):
        task = future_to_task[future]
        try:
            result = future.result()
            print(f"{task} 结果: {result}")
        except Exception as e:
            print(f"{task} 发生异常: {e}")

print("所有任务完成")

# 使用 map 方法
print("\n使用 map 方法:")
tasks = [(f"{i+1}", random.randint(1, 3)) for i in range(5)]

with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
    # 使用 map 提交任务
    results = executor.map(lambda x: worker(x[0], x[1]), tasks)
    
    # 获取结果
    for result in results:
        print(f"结果: {result}")

print("所有任务完成")

线程安全的数据结构

Python 提供了一些线程安全的数据结构,位于 queue 模块中。

队列(Queue)

python
import threading
import queue
import time
import random

# 创建队列
q = queue.Queue()

# 生产者线程函数
def producer():
    """生产者线程"""
    for i in range(10):
        item = f"产品 {i}"
        q.put(item)
        print(f"生产者生产: {item}")
        time.sleep(random.uniform(0.1, 0.5))

# 消费者线程函数
def consumer(name):
    """消费者线程"""
    while True:
        try:
            # 超时获取
            item = q.get(timeout=2)
            print(f"消费者 {name} 消费: {item}")
            q.task_done()
            time.sleep(random.uniform(0.5, 1))
        except queue.Empty:
            print(f"消费者 {name} 等待超时,退出")
            break

# 创建线程
producer_thread = threading.Thread(target=producer)
consumer_thread1 = threading.Thread(target=consumer, args=("1",))
consumer_thread2 = threading.Thread(target=consumer, args=("2",))

# 启动线程
print("主线程开始")
producer_thread.start()
consumer_thread1.start()
consumer_thread2.start()

# 等待生产者完成
producer_thread.join()

# 等待队列处理完成
q.join()

print("主线程结束")

优先级队列(PriorityQueue)

python
import threading
import queue
import time

# 创建优先级队列
q = queue.PriorityQueue()

# 生产者线程函数
def producer():
    """生产者线程"""
    tasks = [
        (3, "低优先级任务"),
        (1, "高优先级任务"),
        (2, "中优先级任务"),
        (1, "另一个高优先级任务"),
        (3, "另一个低优先级任务")
    ]
    
    for priority, task in tasks:
        q.put((priority, task))
        print(f"生产者添加: {task} (优先级: {priority})")
        time.sleep(0.5)

# 消费者线程函数
def consumer():
    """消费者线程"""
    while True:
        try:
            # 超时获取
            priority, task = q.get(timeout=2)
            print(f"消费者处理: {task} (优先级: {priority})")
            q.task_done()
            time.sleep(1)
        except queue.Empty:
            print("队列为空,消费者退出")
            break

# 创建线程
producer_thread = threading.Thread(target=producer)
consumer_thread = threading.Thread(target=consumer)

# 启动线程
print("主线程开始")
producer_thread.start()
consumer_thread.start()

# 等待生产者完成
producer_thread.join()

# 等待队列处理完成
q.join()

print("主线程结束")

后进先出队列(LifoQueue)

python
import threading
import queue
import time

# 创建后进先出队列
q = queue.LifoQueue()

# 生产者线程函数
def producer():
    """生产者线程"""
    for i in range(5):
        item = f"项目 {i}"
        q.put(item)
        print(f"生产者添加: {item}")
        time.sleep(0.5)

# 消费者线程函数
def consumer():
    """消费者线程"""
    while True:
        try:
            # 超时获取
            item = q.get(timeout=2)
            print(f"消费者处理: {item}")
            q.task_done()
            time.sleep(1)
        except queue.Empty:
            print("队列为空,消费者退出")
            break

# 创建线程
producer_thread = threading.Thread(target=producer)
consumer_thread = threading.Thread(target=consumer)

# 启动线程
print("主线程开始")
producer_thread.start()
consumer_thread.start()

# 等待生产者完成
producer_thread.join()

# 等待队列处理完成
q.join()

print("主线程结束")

线程局部存储

线程局部存储(Thread Local Storage)用于为每个线程提供独立的变量副本。

python
import threading
import time

# 创建线程局部存储
local_data = threading.local()

# 线程函数
def thread_function(name):
    """线程函数"""
    # 为每个线程设置独立的变量
    local_data.value = name
    print(f"线程 {name} 本地值: {local_data.value}")
    
    # 模拟耗时操作
    time.sleep(1)
    
    # 再次访问本地变量
    print(f"线程 {name} 本地值(操作后): {local_data.value}")

# 创建线程
thread1 = threading.Thread(target=thread_function, args=("1",))
thread2 = threading.Thread(target=thread_function, args=("2",))

# 启动线程
print("主线程开始")
thread1.start()
thread2.start()

# 等待线程完成
thread1.join()
thread2.join()

print("主线程结束")

多线程的实际应用

示例 1:并行下载文件

python
import threading
import requests
import time
import os

# 下载函数
def download_file(url, filename):
    """下载文件"""
    print(f"开始下载: {url} -> {filename}")
    response = requests.get(url, stream=True)
    with open(filename, 'wb') as f:
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                f.write(chunk)
    print(f"下载完成: {filename}")

# 下载任务
urls = [
    ("https://www.python.org/static/img/python-logo.png", "python-logo.png"),
    ("https://www.python.org/static/img/python-logo-master-v3-TM.png", "python-logo-master.png"),
    ("https://www.python.org/static/img/psf-logo.png", "psf-logo.png")
]

# 使用多线程下载
print("开始下载文件")
start_time = time.time()

threads = []
for url, filename in urls:
    thread = threading.Thread(target=download_file, args=(url, filename))
    threads.append(thread)
    thread.start()

# 等待所有线程完成
for thread in threads:
    thread.join()

end_time = time.time()
print(f"所有文件下载完成,耗时: {end_time - start_time:.2f} 秒")

# 清理文件
for _, filename in urls:
    if os.path.exists(filename):
        os.remove(filename)
        print(f"删除文件: {filename}")

示例 2:并发处理数据

python
import threading
import time
import concurrent.futures

# 处理函数
def process_data(data):
    """处理数据"""
    print(f"开始处理数据: {data}")
    time.sleep(1)  # 模拟处理时间
    result = data * 2
    print(f"数据 {data} 处理完成,结果: {result}")
    return result

# 数据列表
data_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

# 使用线程池处理数据
print("开始处理数据")
start_time = time.time()

with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    # 提交任务
    futures = [executor.submit(process_data, data) for data in data_list]
    
    # 获取结果
    results = []
    for future in concurrent.futures.as_completed(futures):
        results.append(future.result())

end_time = time.time()
print(f"所有数据处理完成,耗时: {end_time - start_time:.2f} 秒")
print(f"处理结果: {results}")

示例 3:多线程 Web 服务器

python
import socket
import threading
import time

# 处理客户端连接
def handle_client(client_socket, client_address):
    """处理客户端连接"""
    print(f"接受到来自 {client_address} 的连接")
    
    try:
        # 接收数据
        data = client_socket.recv(1024)
        print(f"从 {client_address} 收到: {data.decode('utf-8')}")
        
        # 模拟处理时间
        time.sleep(1)
        
        # 发送响应
        response = "HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\nHello, World!"
        client_socket.sendall(response.encode('utf-8'))
    finally:
        # 关闭连接
        client_socket.close()
        print(f"与 {client_address} 的连接已关闭")

# 创建服务器
def start_server():
    """启动服务器"""
    server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    server_socket.bind(('localhost', 8080))
    server_socket.listen(5)
    print("服务器启动,监听端口 8080")
    print("访问地址: http://localhost:8080")
    
    try:
        while True:
            # 接受连接
            client_socket, client_address = server_socket.accept()
            
            # 创建线程处理客户端
            client_thread = threading.Thread(
                target=handle_client, 
                args=(client_socket, client_address)
            )
            client_thread.daemon = True
            client_thread.start()
    finally:
        # 关闭服务器
        server_socket.close()

# 启动服务器
if __name__ == "__main__":
    start_server()

多线程的最佳实践

1. 避免全局变量

  • 使用线程局部存储:为每个线程提供独立的变量
  • 使用队列:通过队列在线程间安全传递数据
  • 使用锁:当必须使用共享变量时,使用锁确保线程安全

2. 合理设置线程数

  • CPU 密集型任务:线程数不宜超过 CPU 核心数
  • I/O 密集型任务:线程数可以适当增加
  • 使用线程池:管理和复用线程,避免线程创建过多

3. 异常处理

  • 捕获线程中的异常:确保线程中的异常不会导致整个程序崩溃
  • 使用 Future:通过 Future 获取线程执行结果和异常

4. 线程生命周期管理

  • 设置线程为守护线程:当主线程退出时,守护线程会自动退出
  • 使用 join():等待线程完成,避免主线程过早退出
  • 避免线程阻塞:确保线程不会无限期阻塞

5. 性能优化

  • 减少锁的范围:只在必要的代码段使用锁
  • 使用无锁数据结构:如 queue.Queue
  • 避免频繁的上下文切换:减少线程数量,合并任务

6. 调试技巧

  • 使用日志:记录线程的执行状态
  • 使用 threading.enumerate():查看当前所有线程
  • 使用 threading.current_thread():获取当前线程信息
  • 使用 timeouts:避免线程无限期等待

总结

本章节介绍了 Python 中的多线程编程,包括:

  1. 基本线程创建:使用 threading.Thread 类和继承 Thread 类
  2. 线程同步:锁、信号量、事件、条件变量、定时器
  3. 线程池:使用 concurrent.futures.ThreadPoolExecutor
  4. 线程安全的数据结构:Queue、PriorityQueue、LifoQueue
  5. 线程局部存储:使用 threading.local()
  6. 实际应用示例:并行下载文件、并发处理数据、多线程 Web 服务器
  7. 最佳实践:避免全局变量、合理设置线程数、异常处理、线程生命周期管理、性能优化、调试技巧

掌握多线程编程,可以提高程序的并发性能,特别是在处理 I/O 密集型任务时。但同时也要注意线程安全问题,避免数据竞争和死锁等问题。