python多进程实践

2024年3月15日 14:49 by wst

python高级

引子

多进程编程网上有很多写法,每种方法都是可行的,看个人喜好。

这里举一个实际当中的例子,说明多进程用法。

问题介绍

有很多个json文件,几万个吧。要统计json文件中某个key对应的value的种类和个数。

解决方法

把json文件分成很多份,每份放入一个进程处理。

然后每个进程的计算结果放到一个字典中。

最后在统计所有字典数据,得到汇总数据。

代码如下

import multiprocessing
import json
import random
from pathlib import Path
import time

COUNT_SIZE = 10000

def random_sample_jsons(folder_path):
    "从文件夹中随机取两个文件"
    folder_path = Path(folder_path)
    # 生成文件列表
    json_files = [entry for entry in folder_path.iterdir() if entry.is_file() and entry.name.lower().endswith("json")]
    selected_jsons = random.sample(json_files, 2)
    return selected_jsons


def segment_files(dir_path, proc_count):
    """分割文件
    dir_path: 来源目录,例如:/home/wst/dataset/24-01-11-1
    proc_count: 进程数,用于判断最后能承受多少文件
    """
    all_sons = []
    for son_dir in Path(dir_path).iterdir():
        for mini_dir in son_dir.iterdir():
            all_sons.append(mini_dir)

    # 统计个数并分割
    length = len(all_sons)
    print("总文件数:", length)
    if length > proc_count * COUNT_SIZE:
        raise Exception("需要调整进程个数")

    samples_per_shards = [all_sons[i::proc_count] for i in range(proc_count)]
    return samples_per_shards


def sample_files_key(files, aim_path, index):
    "处理一批采样文件,统计turn数据"
    tar_fn = f"{index:05d}.tar"
    tmp_fn = aim_path.joinpath(tar_fn)
    print("tar-path:", tmp_fn)

    need = {'turn': [], 'brake': []}
    for file_dir in files:
        json_dir = file_dir.joinpath("4dbev_json")
        json_files = random_sample_jsons(json_dir)
        for jf in json_files:
            good_json = json.load(open(jf))
            for obstacle in good_json['annotations']:
                if 'front' in obstacle:
                    light = obstacle['light']
                    need['turn'].append(light['turn_signal_status'])
                    need['brake'].append(light['brake_status'])
    return need


def all_sample_images_debug(from_path, aim_path, proc_count=8):
    """采样所有数据
    from_path: 来源路径,如:/home/wst/dataset/24-02-02-1
    aim_path: 目标路径,如:/home/wst/results
    """
    shards = segment_files(from_path, proc_count)
    pool = multiprocessing.Pool(processes=proc_count)
    need_list = []
    for index, shard in enumerate(shards):
        res = pool.apply_async(sample_files_key, (shard, aim_path, index))
        need_list.append(res)
    pool.close()
    pool.join()
    dic = {
        'turn': {},
        'brake': {}
    }
    for need in need_list:
        data = need.get()
        for key in data:
            for ele in data[key]:
                dic[key][ele] = 1+dic[key].get(ele, 0)
    print("end_turn_brake:", dic)

if __name__ == "__main__":
    # 采样文件
    start = time.time()
    dir_path = "/home/wst/dataset/24-02-02-1"
    aim_path = Path("/home/wst/results/")
    count = all_sample_images_debug(dir_path, aim_path, 32)
    print("sample done!")
    end = time.time()
    print("use time:", int(end-start), "秒")

 

 


Comments(0) Add Your Comment

Not Comment!