菜鸟笔记
提升您的技术认知

C++ 利用linux函数makecontext等实现简单协程

我是一个编程新手,最近了解到协程这个概念。协程可以理解为用户级线程,由程序在用户空间自己调度。在处理异步IO时,协程可以在等待IO时主动让出CPU给其他协程,等事件完成后再切换回来继续执行。当然用回调也可以实现同样的功能,但是使用协程能让程序“看起来”是顺序执行的。

我利用linux系统函数getcontext,makecontext,swapcontext来实现协程之间的切换。

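getcontext(ucontext_t* ucp)把当前上下文保存到ucp中,用来初始化一个上下文;makecontext(ucontext_t* ucp, void (*func)(), int argc, ...)修改ucp,绑定切换到该上下文时要执行的函数func,argc是后面传给func的参数个数(这些参数按int传递),调用前必须先用getcontext初始化ucp,并设置好uc_stack和uc_link;swapcontext(ucontext_t* oucp, const ucontext_t* ucp)把当前上下文保存到oucp,并切换到ucp。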

    // Excerpt of the ucontext_t definition; the remaining fields are elided (…).
    typedef struct ucontext {
               struct ucontext *uc_link;          // context to switch to when the current context finishes executing
               sigset_t         uc_sigmask;
               stack_t          uc_stack;         // stack used by the current context
               mcontext_t       uc_mcontext;
               …
           } ucontext_t;

我利用这些函数实现了一个简单的协程库:在一个线程中实现多个协程之间的切换,使用epoll等待IO事件或定时器事件,事件发生后唤醒对应的协程;当epoll_wait超时(没有事件发生)时,按顺序轮流切换到没有在等待事件的协程。下面是线程工作函数:

void	MyUThread::thread_work(void* pvoid) {
	MyUThread* mut = (MyUThread*)pvoid;
	struct epoll_event events[1024];
	while (!mut->shut_down_) {
		int ret = epoll_wait(mut->epoll_fd_, events, 1024, 4);
		if (ret > 0) {//有事件发生
			for (int i = 0; i < ret; i++) {
				UContextRevent* ctxex = (UContextRevent*)events[i].data.ptr;
				auto ctx = ctxex->ctx_;
				if (ctx->get_timer() && ctxex->fd_ == ctx->get_timer()->get_timerfd()) {
					ctxex->revents_ = UThreadTimerOut;
				}else {
					ctxex->revents_ = events[i].events;
				}
				ctx->resume();
				if (ctx->status_ == UContext::FINISHED) {
					mut->remove_uthread(ctx->ctx_index_);
				}
			}
		}else if (ret == -1) {//异常
			if (errno == EINTR) {
				continue;
			}
			else {
				perror("epoll_wait:");
				abort();
			}
		}else {//epoll_wait超时
			++mut->curr_running_index_;
			if (mut->curr_running_index_ >= mut->curr_max_count_) {
				mut->curr_running_index_ = -1;
				continue;
			}
			auto ctx = mut->ctx_items_[mut->curr_running_index_];
			if (ctx == nullptr || !ctx->register_fd_set.empty()) {
				//如果该ctx协程正在等待某个异步事件,则不对他唤醒
				continue;
			}
			ctx->resume();
			if (ctx->status_ == UContext::FINISHED) {
				mut->remove_uthread(ctx->ctx_index_);
			}
		}
	}
}

使用:

#include <iostream>
#include <functional>
#include <vector>
#include "MyUThread.h"
#include "MyTimer.h"
using namespace std;
using namespace placeholders;

void	test1(UContext*& ctx, int num) {
	cout << "enter test1:" << num << endl;

	// Arm a one-shot 2-second timer and wait on its fd through the scheduler's epoll.
	MyTimer timer;
	timer.set_once(2, 0);
	const int timer_fd = timer.get_timerfd();
	UContextRevent wait_result(ctx, timer_fd);
	ctx->register_event(timer_fd, EPOLLIN, &wait_result);
	ctx->attach_timer(&timer);

	cout << "start yield" << endl;
	ctx->yield();	// give up the thread until the scheduler resumes this coroutine

	const bool timed_out = (wait_result.revents_ == UThreadTimerOut);
	if (timed_out) {
		cout << "UThreadTimerOut" << endl;
	}
	ctx->remove_event(timer_fd);
	cout << "finished test1" << endl;
}

void	test2(UContext*& ctx) {
	// Print, hand control back to the scheduler once, then finish on the next resume.
	std::cout << "test2 yield" << std::endl;
	UContext* const self = ctx;
	self->yield();
	std::cout << "finished test2" << std::endl;
}

void	finished() {
	// Completion callback passed to add_task for test1.
	std::cout << "test1 callback" << std::endl;
}

int main()
{
	MyUThread mut(2, 8192);//参数:协程最大数量,协程栈大小
	mut.add_task(bind(test1, _1, 99999),finished);//添加任务,参数:工作函数,回调函数
	for (int i = 0; i < 10; i++) {
		mut.add_task(test2);
	}
	mut.join();
	return 0;
}

下面贴出源代码:我还是个小白,把代码贴出来,如果有什么错误希望大家评论告诉我,谢谢(ㅎ-ㅎ;)

UContext.h

#include <functional>
#include <ucontext.h>
#include <memory>
#include <unordered_set>
using namespace std;
/*
 * epoll event bits, for reference:
 *   EPOLLIN 0x001, EPOLLPRI 0x002, EPOLLOUT 0x004, EPOLLERR 0x008,
 *   EPOLLHUP 0x010, EPOLLRDHUP 0x2000
 *
 * UThreadTimerOut is stored in UContextRevent::revents_ instead of the real
 * epoll bits when the fd that fired is the coroutine's attached timer.
 * It must not overlap any bit epoll_wait can report. The value 3 would equal
 * EPOLLIN|EPOLLPRI, so a socket with urgent data could not be told apart
 * from a timeout. 0x40000000 is EPOLLONESHOT, an input-only flag that
 * epoll_wait never returns in revents.
 */
#define UThreadTimerOut 0x40000000
class UContext;
typedef function<void(UContext*&)>      UContextFunc;
typedef function<void()>                CallBack;
class MyTimer;
struct UContextRevent {
	UContext*    ctx_;
	int          fd_;
	int          revents_;
	UContextRevent(UContext* ctx, int fd,int revents = 0) :
		ctx_(ctx), fd_(fd), revents_(revents){}
};
class UContext {
public:
	UContext(int index, ucontext_t* main_ctx, int stack_size, int epoll_fd);
	~UContext();
	void	set_func(UContextFunc func, CallBack callback = 0);
	void	make();
	void	resume();
	void	yield();
	void	attach_timer(MyTimer*);
	void	register_event(int fd, int events, UContextRevent* revents);
	void	remove_event(int fd);
	MyTimer* get_timer();
private:
	static void	work_func(uint32_t low32, uint32_t high32);
public:
        int             ctx_index_;		//当前对象在MyUThread对象的数组中保存的索引
	ucontext_t*     main_ctx_;		//线程主逻辑上下文
	ucontext_t*     ctx_;			//当前上下文
	char*           raw_stack_;		//上下文使用栈空间
	char*           stack_;			//栈空间(安全保护)
	UContextFunc    func_;			//用户任务
	CallBack        callback_;		//任务回调
	int             stack_size_;	//栈大小
	enum { READY = 0, RUNNING, SUSPEND, FINISHED };
	int             status_;		//当前协程状态
	int             epoll_fd_;		//epoll fd
	MyTimer*        mytimer_;		
	unordered_set<int>    register_fd_set;
};

MyUThread.h

#include <ucontext.h>
#include "sys/epoll.h"

#include <atomic>
#include <cerrno>
#include <chrono>
#include <condition_variable>
#include <cstdio>
#include <cstdlib>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>
using namespace std;
#include "UContext.h"
class MyUThread {
public:
	MyUThread(int max_uthread_count, int stack_size);
	~MyUThread();
	void	add_task(UContextFunc func, CallBack callback = 0);
	void	join();
	void	destory();
private:
	void	remove_uthread(int index);
	int		get_stack_size(int stack_size);
	static void thread_work(void* pvoid);
private:
	int					max_uthread_count_;		//最大协程数量
	int					stack_size_;			//栈大小
	volatile int		curr_running_index_;	//当前执行协程索引
	volatile int		curr_max_count_;		//当前最大协程数量
	volatile int		idle_count_;			//可用协程数量
	bool				shut_down_;				//是否退出
	vector<UContext*>	ctx_items_;				//调度的协程列表
	queue<UContext*>	ctx_ready_queue_;		//就绪队列,等待其它协程退出后,会被添加到ctx_items_中
	thread*				thread_;				//线程
	ucontext_t			main_ctx_;				//主上下文
	mutex				mutex_;					//锁queue
	mutex				mutex_join_;			
	condition_variable	cv_;					//使用锁和条件变量,阻塞等待协程全部执行完毕
	int					epoll_fd_;				//epoll fd
};

MyTimer.h

#pragma once
#include <iostream>
#include <memory>
using namespace std;
class UContext;
class MyTimer
{
public:
	// Creates a non-blocking CLOCK_MONOTONIC timerfd; the timer is not armed yet.
	MyTimer();
	~MyTimer();
	// Owns a file descriptor and a heap-allocated itimerspec; a copy would
	// close/free them twice.
	MyTimer(const MyTimer&) = delete;
	MyTimer& operator=(const MyTimer&) = delete;
	// Arm a one-shot timer that expires after seconds + millseconds.
	void    set_once(int seconds, int millseconds);
	// Arm a periodic timer: first expiry after seconds + millseconds, then every interval.
	void	set_cycle(int seconds, int millseconds, int intervalSeconds, int intervalMillSeconds);
	// The timerfd, suitable for registering with epoll.
	int		get_timerfd();
	// Disarm the timer.
	void	stop();
	// Nanoseconds remaining until the timer expires.
	// NOTE(review): an int can hold at most ~2.1 s worth of nanoseconds.
	int 	get_time();
private:
	void	set_time(int seconds, int millseconds, int intervalSeconds, int intervalMillSeconds);
private:
	int					timerfd_;
	struct itimerspec*	timespec_;
};

UContext.cpp

#include "UContext.h"
#include <sys/mman.h>
#include <unistd.h>
#include <assert.h>
#include <sys/epoll.h>
#include "MyTimer.h"

UContext::UContext(int index, ucontext_t* main_ctx, int stack_size, int epoll_fd)
	:ctx_index_(index), main_ctx_(main_ctx), ctx_(nullptr),
	raw_stack_(nullptr), stack_(nullptr), func_(nullptr), callback_(nullptr),
	stack_size_(stack_size),
	status_(READY), epoll_fd_(epoll_fd),
	mytimer_(nullptr){
	// Private coroutine stack layout: [guard page][stack_size_ bytes][guard page].
	// The PROT_NONE guard pages make an overflow/underflow fault immediately.
	const int page_size = getpagesize();
	const size_t map_size = static_cast<size_t>(stack_size_) + static_cast<size_t>(page_size) * 2;

	ctx_ = new ucontext_t;
	void* mem = mmap(nullptr, map_size,
		PROT_READ | PROT_WRITE, MAP_ANONYMOUS | MAP_PRIVATE, -1, 0);
	if (mem == MAP_FAILED) {
		// mmap reports failure with MAP_FAILED, never nullptr.
		delete ctx_;
		ctx_ = nullptr;
		throw std::bad_alloc();
	}
	raw_stack_ = static_cast<char*>(mem);
	stack_ = raw_stack_ + page_size;

	// Call mprotect unconditionally and only check the result in assert(),
	// so the guard pages still exist in NDEBUG builds.
	const int low_guard = mprotect(raw_stack_, page_size, PROT_NONE);
	const int high_guard = mprotect(stack_ + stack_size_, page_size, PROT_NONE);
	assert(low_guard == 0 && high_guard == 0);
	(void)low_guard;
	(void)high_guard;

	// makecontext() needs a context initialised by getcontext(); set uc_link and
	// uc_stack afterwards so getcontext cannot overwrite them.
	const int got = getcontext(ctx_);
	assert(got == 0);
	(void)got;
	ctx_->uc_link = main_ctx;
	ctx_->uc_stack.ss_sp = stack_;
	ctx_->uc_stack.ss_size = stack_size_;
	ctx_->uc_stack.ss_flags = 0;
}
UContext::~UContext() {
	// Unregister any fd this coroutine left in the epoll set. Its data.ptr points
	// at a UContextRevent, which may live on the coroutine stack unmapped below,
	// so a later event would dereference freed memory.
	// NOTE(review): if such an fd was already closed and its number reused by
	// another registration, this DEL removes that one instead. Callers should
	// still pair register_event with remove_event.
	for (int fd : register_fd_set) {
		epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, nullptr);
	}
	register_fd_set.clear();
	delete ctx_;
	if (raw_stack_ != nullptr) {
		// Same size that the constructor mapped: usable stack plus two guard pages.
		munmap(raw_stack_, stack_size_ + getpagesize() * 2);
	}
}

void	UContext::set_func(UContextFunc func, CallBack callback) {
	// Store the task body and its optional completion callback. Both parameters
	// are by-value copies owned by this call, so move them into the members
	// instead of copying a second time.
	func_ = std::move(func);
	callback_ = std::move(callback);
}

void	UContext::make() {
	// makecontext forwards only int-sized arguments, so pass `this` as two
	// 32-bit halves; work_func reassembles them. Widen to 64 bits first:
	// shifting a 32-bit uintptr_t by 32 would be undefined behaviour.
	const uint64_t bits = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(this));
	makecontext(ctx_, reinterpret_cast<void (*)(void)>(work_func), 2,
		static_cast<uint32_t>(bits), static_cast<uint32_t>(bits >> 32));
}

void	UContext::resume() {
	status_ = RUNNING;
	swapcontext(main_ctx_, ctx_);
}

void	UContext::yield() {
	status_ = SUSPEND;
	swapcontext(ctx_, main_ctx_);
}

void	UContext::attach_timer(MyTimer* timer) {
	// Remember the timer without taking ownership; passing nullptr detaches it.
	this->mytimer_ = timer;
}

MyTimer* UContext::get_timer() {
	// Currently attached timer, or nullptr if none.
	return this->mytimer_;
}

void	UContext::register_event(int fd, int events, UContextRevent* revents) {
	// Watch fd for `events` on the epoll instance. When it fires, the
	// scheduler writes revents->revents_ and resumes this coroutine.
	epoll_event ev{};
	ev.events = static_cast<uint32_t>(events);
	// epoll_data is a union: store only the pointer the scheduler reads back
	// (writing data.fd as well would just be overwritten).
	ev.data.ptr = revents;
	int rc = epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, &ev);
	if (rc == -1 && errno == EEXIST) {
		// Already registered (e.g. with an older UContextRevent): update it so
		// the scheduler does not write through a stale pointer.
		rc = epoll_ctl(epoll_fd_, EPOLL_CTL_MOD, fd, &ev);
	}
	if (rc == 0) {
		// Only record successful registrations. A phantom entry would make the
		// round-robin pass skip this coroutine forever while no event can arrive.
		register_fd_set.insert(fd);
	}
}

void	UContext::remove_event(int fd) {
	epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, nullptr);
	if (mytimer_ && mytimer_->get_timerfd() == fd) {
		this->attach_timer(nullptr);
	}
	register_fd_set.erase(fd);
}

void	UContext::work_func(uint32_t low32, uint32_t high32) {
	uintptr_t ptr = (uintptr_t)low32 | ((uintptr_t)high32 << 32);
	UContext * uc = (UContext*)ptr;
	if (uc->func_) {
		uc->func_(uc);
		if (uc->callback_) {
			uc->callback_();
		}
	}
	uc->status_ = FINISHED;
}

MyUThread.cpp

#include "MyUThread.h"
#include "MyTimer.h"
#include <iostream>
#include <cstring>
#include <unistd.h>
#include <assert.h>
#include <functional>
#include <sys/timerfd.h>
using namespace std;
MyUThread::MyUThread(int max_uthread_count, int stack_size) :
	// Clamp to at least one slot. A negative count would make resize() below
	// throw, and zero slots would leave every task in the ready queue and
	// join() waiting forever.
	max_uthread_count_(max_uthread_count > 0 ? max_uthread_count : 1),
	stack_size_(get_stack_size(stack_size)),
	curr_running_index_(-1),
	curr_max_count_(0),
	idle_count_(max_uthread_count > 0 ? max_uthread_count : 1),
	shut_down_(false), thread_(nullptr), epoll_fd_(-1)
{
	memset(&main_ctx_, 0, sizeof(main_ctx_));
	ctx_items_.resize(max_uthread_count_, nullptr);
	epoll_fd_ = epoll_create(100);	// size hint is ignored by modern kernels but must be > 0
	if (epoll_fd_ == -1) {
		perror("epoll_create");
		abort();
	}
	// Start the worker last: it immediately uses the members initialised above.
	thread_ = new thread(thread_work, static_cast<void*>(this));
}
MyUThread::~MyUThread() {
	// Stop the worker first so nothing else touches the containers below.
	// Deleting a still-joinable std::thread would call std::terminate.
	shut_down_ = true;
	if (thread_ != nullptr && thread_->joinable()) {
		thread_->join();
	}
	delete thread_;
	thread_ = nullptr;
	// The containers hold raw owning pointers; free them instead of just
	// clearing (coroutines that never finished are dropped).
	for (UContext* ctx : ctx_items_) {
		delete ctx;
	}
	ctx_items_.clear();
	while (!ctx_ready_queue_.empty()) {
		delete ctx_ready_queue_.front();
		ctx_ready_queue_.pop();
	}
	// Close the epoll fd after the contexts, whose cleanup may still use it.
	if (epoll_fd_ >= 0) {
		close(epoll_fd_);
		epoll_fd_ = -1;
	}
}

void	MyUThread::add_task(UContextFunc func, CallBack callback) {
	// Wrap the task in a coroutine and put it in a free slot, or in the ready
	// queue when every slot is busy. Build the coroutine outside the lock
	// (mmap/getcontext).
	auto ctx = new UContext(-1, &main_ctx_, stack_size_, epoll_fd_);
	ctx->set_func(std::move(func), std::move(callback));

	// The worker thread reads ctx_items_ and refills slots from
	// ctx_ready_queue_ concurrently, so choose and publish the slot under mutex_.
	lock_guard<mutex> lock(mutex_);
	int index = -1;
	if (curr_max_count_ < max_uthread_count_) {
		index = curr_max_count_;	// grow into a never-used slot
	} else if (idle_count_ > 0) {
		for (int i = 0; i < curr_max_count_; ++i) {	// reuse a freed slot
			if (ctx_items_[i] == nullptr) {
				index = i;
				break;
			}
		}
	}
	if (index < 0) {
		// No free slot: park the task until a running coroutine finishes.
		ctx_ready_queue_.push(ctx);
		return;
	}
	ctx->ctx_index_ = index;
	ctx->make();
	ctx_items_[index] = ctx;
	if (index == curr_max_count_) {
		++curr_max_count_;	// publish the new slot only after it is filled
	}
	--idle_count_;
}

void	MyUThread::join() {
	while (idle_count_ < max_uthread_count_ || !ctx_ready_queue_.empty()) {
		unique_lock<mutex> lock(mutex_join_);
		cv_.wait(lock);
	}
	this->shut_down_ = true;
	this->thread_->join();
}

void	MyUThread::destory() {
	// Stop the worker thread without waiting for outstanding tasks.
	// Idempotent: joining a non-joinable std::thread would throw.
	this->shut_down_ = true;
	if (this->thread_ != nullptr && this->thread_->joinable()) {
		this->thread_->join();
	}
}

void	MyUThread::remove_uthread(int index) {
	if (index >= 0 && index < curr_max_count_) {
		delete ctx_items_[index];
		ctx_items_[index] = nullptr;
		++idle_count_;
		cv_.notify_all();
	}
	UContext* ctx = nullptr;
	{
		lock_guard<mutex> lock(mutex_);
		if (!ctx_ready_queue_.empty()) {
			ctx = ctx_ready_queue_.front();
			ctx_ready_queue_.pop();
		}
	}
	if (ctx) {
		ctx->ctx_index_ = index;
		ctx->make();
		ctx_items_[index] = ctx;
		--idle_count_;
	}
}

int		MyUThread::get_stack_size(int stack_size) {
	// Round the requested stack size up to a whole number of pages (minimum one page).
	const int page = getpagesize();
	if (stack_size >= page) {
		const int whole_pages = stack_size / page;
		const bool has_partial_page = (stack_size % page) != 0;
		return (has_partial_page ? whole_pages + 1 : whole_pages) * page;
	}
	return page;
}

void	MyUThread::thread_work(void* pvoid) {
	MyUThread* mut = (MyUThread*)pvoid;
	struct epoll_event events[1024];
	while (!mut->shut_down_) {
		int ret = epoll_wait(mut->epoll_fd_, events, 1024, 4);
		if (ret > 0) {//有事件发生
			for (int i = 0; i < ret; i++) {
				UContextRevent* ctxex = (UContextRevent*)events[i].data.ptr;
				auto ctx = ctxex->ctx_;
				if (ctx->get_timer() && ctxex->fd_ == ctx->get_timer()->get_timerfd()) {
					ctxex->revents_ = UThreadTimerOut;
				}else {
					ctxex->revents_ = events[i].events;
				}
				ctx->resume();
				if (ctx->status_ == UContext::FINISHED) {
					mut->remove_uthread(ctx->ctx_index_);
				}
			}
		}else if (ret == -1) {//异常
			if (errno == EINTR) {
				continue;
			}
			else {
				perror("epoll_wait:");
				abort();
			}
		}else {//epoll_wait超时
			++mut->curr_running_index_;
			if (mut->curr_running_index_ >= mut->curr_max_count_) {
				mut->curr_running_index_ = -1;
				continue;
			}
			auto ctx = mut->ctx_items_[mut->curr_running_index_];
			if (ctx == nullptr || !ctx->register_fd_set.empty()) {
				//如果该ctx协程正在等待某个异步事件,则不对他唤醒
				continue;
			}
			ctx->resume();
			if (ctx->status_ == UContext::FINISHED) {
				mut->remove_uthread(ctx->ctx_index_);
			}
		}
	}
}

MyTimer.cpp

#include "MyTimer.h"
#include "UContext.h"
#include <cstring>
#include <unistd.h>
#include <sys/epoll.h>
#include <sys/timerfd.h>
#include <fcntl.h>

MyTimer::MyTimer():timerfd_(-1), timespec_(nullptr)
{
	// TFD_NONBLOCK already makes the fd non-blocking, so no extra
	// fcntl(F_GETFL/F_SETFL) round trip is needed.
	timerfd_ = timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK);
	if (timerfd_ == -1) {
		abort();
	}
	// Value-initialisation zeroes it_value and it_interval (replaces memset).
	timespec_ = new struct itimerspec();
}

MyTimer::~MyTimer()
{
	// Release the cached itimerspec (delete on nullptr is a no-op) and close the timer fd.
	delete timespec_;
	::close(timerfd_);
}

void    MyTimer::set_once(int seconds, int millseconds) {
	this->set_time(seconds, millseconds, 0, 0);
}

void	MyTimer::set_cycle(int seconds, int millseconds, int intervalSeconds, int intervalMillSeconds) {
	this->set_time(seconds, millseconds, intervalSeconds, intervalMillSeconds);
}

int		MyTimer::get_timerfd() {
	return timerfd_;
}

void	MyTimer::set_time(int seconds, int millseconds, int intervalSeconds, int intervalMillSeconds) {
	// Arm (or, with an all-zero first expiry, disarm) the timerfd: first
	// expiry after seconds + millseconds, then every intervalSeconds +
	// intervalMillSeconds (zero interval = one-shot). Aborts if
	// timerfd_settime fails.
	if (timerfd_ == -1 || timespec_ == nullptr)
		return;
	// Milliseconds must be converted to nanoseconds (x 1,000,000, not x 1000).
	// Whole seconds are folded out so tv_nsec stays below 1e9; timerfd_settime
	// rejects larger values with EINVAL.
	auto fill = [](struct timespec& ts, int sec, int ms) {
		ts.tv_sec = static_cast<time_t>(sec) + ms / 1000;
		ts.tv_nsec = static_cast<long>(ms % 1000) * 1000000L;
	};
	fill(timespec_->it_value, seconds, millseconds);
	fill(timespec_->it_interval, intervalSeconds, intervalMillSeconds);
	if (-1 == timerfd_settime(timerfd_, 0, timespec_, nullptr)) {
		abort();
	}
}

void	MyTimer::stop() {
	this->set_time(0, 0, 0, 0);
}

int 	MyTimer::get_time() {
	// Nanoseconds remaining until the next expiry (0 when disarmed), or -1 if
	// timerfd_gettime fails.
	struct itimerspec remaining{};
	if (timerfd_gettime(timerfd_, &remaining) == -1) {
		return -1;
	}
	// Seconds -> nanoseconds is x 1,000,000,000; compute in 64 bits so it
	// cannot overflow.
	const long long ns = static_cast<long long>(remaining.it_value.tv_sec) * 1000000000LL
		+ remaining.it_value.tv_nsec;
	// The int return type is kept for compatibility; it cannot hold more than
	// INT_MAX (~2.1 s) of nanoseconds, so larger values saturate.
	const long long int_max = static_cast<long long>(~0u >> 1);	// == INT_MAX
	return ns > int_max ? static_cast<int>(int_max) : static_cast<int>(ns);
}

这就是我简单编写的全部内容,欢迎指正~