我是一个编程新手,最近了解到协程这个概念,协程可以理解为用户级线程,在用户空间实现调度,在处理异步IO时,可以在子程序中让出cpu交给其他协程,等事件完成再切换到子程序中。当然回调也可以实现,但是使用协程会使程序“看起来”是顺序执行的。
我利用linux系统函数getcontext,makecontext,swapcontext来实现协程之间的切换。
getcontext(ucontext_t*) 把当前的执行上下文保存到参数指向的结构中(通常用来初始化一个上下文);makecontext(ucontext_t*, void(*)(void), int argc, ...) 为该上下文绑定切换过去时要执行的函数,argc 是后面可变参数的个数;swapcontext(ucontext_t* ouc, ucontext_t* uc) 把当前上下文保存到 ouc,并切换到 uc。
typedef struct ucontext {
struct ucontext *uc_link; //指向当前上下文执行完毕后要切换到的上下文
sigset_t uc_sigmask;
stack_t uc_stack; //当前上下文使用的堆栈
mcontext_t uc_mcontext;
…
} ucontext_t;
我利用这些函数实现了一个简单的协程库:在一个线程中调度多个协程,使用 epoll 等待 IO 事件或定时器事件,事件发生后唤醒对应的协程;当没有事件发生(epoll_wait 超时)时,则按顺序轮流切换到没有在等待事件的协程。下面是线程工作函数:
void MyUThread::thread_work(void* pvoid) {
MyUThread* mut = (MyUThread*)pvoid;
struct epoll_event events[1024];
while (!mut->shut_down_) {
int ret = epoll_wait(mut->epoll_fd_, events, 1024, 4);
if (ret > 0) {//有事件发生
for (int i = 0; i < ret; i++) {
UContextRevent* ctxex = (UContextRevent*)events[i].data.ptr;
auto ctx = ctxex->ctx_;
if (ctx->get_timer() && ctxex->fd_ == ctx->get_timer()->get_timerfd()) {
ctxex->revents_ = UThreadTimerOut;
}else {
ctxex->revents_ = events[i].events;
}
ctx->resume();
if (ctx->status_ == UContext::FINISHED) {
mut->remove_uthread(ctx->ctx_index_);
}
}
}else if (ret == -1) {//异常
if (errno == EINTR) {
continue;
}
else {
perror("epoll_wait:");
abort();
}
}else {//epoll_wait超时
++mut->curr_running_index_;
if (mut->curr_running_index_ >= mut->curr_max_count_) {
mut->curr_running_index_ = -1;
continue;
}
auto ctx = mut->ctx_items_[mut->curr_running_index_];
if (ctx == nullptr || !ctx->register_fd_set.empty()) {
//如果该ctx协程正在等待某个异步事件,则不对他唤醒
continue;
}
ctx->resume();
if (ctx->status_ == UContext::FINISHED) {
mut->remove_uthread(ctx->ctx_index_);
}
}
}
}
使用:
#include <iostream>
#include <functional>
#include <vector>
#include "MyUThread.h"
#include "MyTimer.h"
using namespace std;
using namespace placeholders;
// Demo task: arms a 2-second one-shot timer, registers its descriptor for
// EPOLLIN, yields, and after being resumed reports whether the wake-up was
// marked as UThreadTimerOut.
void test1(UContext*& ctx, int num) {
    cout << "enter test1:" << num << endl;

    MyTimer timer;
    timer.set_once(2, 0);
    const int timer_fd = timer.get_timerfd();

    // The record lives in this frame and is handed to the scheduler by pointer.
    UContextRevent revents(ctx, timer_fd);
    ctx->register_event(timer_fd, EPOLLIN, &revents);
    ctx->attach_timer(&timer);

    cout << "start yield" << endl;
    ctx->yield();

    const bool timed_out = (revents.revents_ == UThreadTimerOut);
    if (timed_out) {
        cout << "UThreadTimerOut" << endl;
    }

    ctx->remove_event(timer_fd);
    cout << "finished test1" << endl;
}
// Demo task: prints a line, yields once without registering any event, and
// prints a second line when it is resumed.
void test2(UContext*& ctx) {
    cout << "test2 yield" << endl;
    UContext* const self = ctx;
    self->yield();
    cout << "finished test2" << endl;
}
// Completion callback passed alongside test1 in main(); only prints a marker.
void finished() {
    const char* const message = "test1 callback";
    cout << message << endl;
}
// Demo driver: a scheduler constructed with (2, 8192) receives one test1 task
// (with the `finished` callback) and ten test2 tasks, then waits in join().
int main()
{
    MyUThread mut(2, 8192);  // arguments: max coroutine count, stack size
    mut.add_task(bind(test1, _1, 99999), finished);

    int remaining = 10;
    while (remaining-- > 0) {
        mut.add_task(test2);
    }

    mut.join();
    return 0;
}
下面贴出源代码:我还是个小白,把代码贴出来,如果有什么错误希望大家评论告诉我,谢谢(ㅎ-ㅎ;)
UContext.h
#include <functional>
#include <ucontext.h>
#include <memory>
#include <unordered_set>
using namespace std;
/*
 * Numeric values of the epoll event bits, which the scheduler copies into
 * UContextRevent::revents_:
 *   EPOLLIN:1
 *   EPOLLOUT:4
 *   EPOLLRDHUP:8192
 *   EPOLLPRI:2
 *   EPOLLERR:8
 *   EPOLLHUP:16
 *
 * UThreadTimerOut is the scheduler's own "timer expired" marker stored in the
 * same field. It must not be expressible as a combination of the bits above:
 * the previous value 3 was identical to EPOLLIN|EPOLLPRI. 0x10000 overlaps
 * none of the listed bits.
 */
#define UThreadTimerOut 0x10000
class UContext;
// Task body; receives the coroutine it runs on.
using UContextFunc = std::function<void(UContext*&)>;
// Optional completion callback, invoked after the task body returns.
using CallBack = std::function<void()>;
class MyTimer;
// Record attached to an epoll registration (stored in epoll_event::data.ptr by
// UContext::register_event). The scheduler reads ctx_ and fd_ and writes the
// outcome into revents_.
struct UContextRevent {
    UContext* ctx_;  // coroutine that registered the descriptor
    int fd_;         // the registered descriptor
    int revents_;    // result slot: epoll event bits or UThreadTimerOut

    UContextRevent(UContext* ctx, int fd, int revents = 0) {
        ctx_ = ctx;
        fd_ = fd;
        revents_ = revents;
    }
};
class UContext {
public:
UContext(int index, ucontext_t* main_ctx, int stack_size, int epoll_fd);
~UContext();
void set_func(UContextFunc func, CallBack callback = 0);
void make();
void resume();
void yield();
void attach_timer(MyTimer*);
void register_event(int fd, int events, UContextRevent* revents);
void remove_event(int fd);
MyTimer* get_timer();
private:
static void work_func(uint32_t low32, uint32_t high32);
public:
int ctx_index_; //当前对象在MyUThread对象的数组中保存的索引
ucontext_t* main_ctx_; //线程主逻辑上下文
ucontext_t* ctx_; //当前上下文
char* raw_stack_; //上下文使用栈空间
char* stack_; //栈空间(安全保护)
UContextFunc func_; //用户任务
CallBack callback_; //任务回调
int stack_size_; //栈大小
enum { READY = 0, RUNNING, SUSPEND, FINISHED };
int status_; //当前协程状态
int epoll_fd_; //epoll fd
MyTimer* mytimer_;
unordered_set<int> register_fd_set;
};
MyUThread.h
#include <ucontext.h>
#include "sys/epoll.h"

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

using namespace std;
#include "UContext.h"
class MyUThread {
public:
MyUThread(int max_uthread_count, int stack_size);
~MyUThread();
void add_task(UContextFunc func, CallBack callback = 0);
void join();
void destory();
private:
void remove_uthread(int index);
int get_stack_size(int stack_size);
static void thread_work(void* pvoid);
private:
int max_uthread_count_; //最大协程数量
int stack_size_; //栈大小
volatile int curr_running_index_; //当前执行协程索引
volatile int curr_max_count_; //当前最大协程数量
volatile int idle_count_; //可用协程数量
bool shut_down_; //是否退出
vector<UContext*> ctx_items_; //调度的协程列表
queue<UContext*> ctx_ready_queue_; //就绪队列,等待其它协程退出后,会被添加到ctx_items_中
thread* thread_; //线程
ucontext_t main_ctx_; //主上下文
mutex mutex_; //锁queue
mutex mutex_join_;
condition_variable cv_; //使用锁和条件变量,阻塞等待协程全部执行完毕
int epoll_fd_; //epoll fd
};
MyTimer.h
#pragma once
#include <iostream>
#include <memory>
using namespace std;
class UContext;
// Thin wrapper around one timer file descriptor and the itimerspec used to
// arm it.
class MyTimer
{
public:
    MyTimer();
    ~MyTimer();

    // The destructor closes timerfd_ and deletes timespec_; a copy would do
    // both a second time, so copying is disabled.
    MyTimer(const MyTimer&) = delete;
    MyTimer& operator=(const MyTimer&) = delete;

    // Arms the timer to fire once after seconds + millseconds.
    void set_once(int seconds, int millseconds);
    // Arms the timer to fire first after seconds + millseconds and then every
    // intervalSeconds + intervalMillSeconds.
    void set_cycle(int seconds, int millseconds, int intervalSeconds, int intervalMillSeconds);
    int get_timerfd();
    void stop();
    // Returns how many nanoseconds remain until the timer expires.
    int get_time();

private:
    void set_time(int seconds, int millseconds, int intervalSeconds, int intervalMillSeconds);

private:
    int timerfd_;                  // timer descriptor, -1 if not created
    struct itimerspec* timespec_;  // owned; last values passed to the kernel
};
UContext.cpp
#include "UContext.h"
#include <sys/mman.h>
#include <unistd.h>
#include <assert.h>
#include <sys/epoll.h>
#include "MyTimer.h"
// Builds a coroutine context with its own stack.
//
// The stack is an anonymous private mapping of stack_size bytes with one
// inaccessible page below and one above it, so running off either end faults
// instead of corrupting neighbouring memory. The context is captured with
// getcontext() and then pointed at that stack, with main_ctx as the context to
// continue in when the entry function returns.
//
// Failures are reported through assert(), as before. Every call with a side
// effect now sits outside the assert expression, so the guard pages are still
// installed when NDEBUG is defined.
UContext::UContext(int index, ucontext_t* main_ctx, int stack_size, int epoll_fd)
    : ctx_index_(index), main_ctx_(main_ctx), ctx_(nullptr),
      raw_stack_(nullptr), stack_(nullptr), func_(nullptr), callback_(nullptr),
      stack_size_(stack_size),
      status_(READY), epoll_fd_(epoll_fd),
      mytimer_(nullptr) {
    const size_t page_size = static_cast<size_t>(getpagesize());
    const size_t mapping_size = static_cast<size_t>(stack_size_) + page_size * 2;

    void* mapping = mmap(nullptr, mapping_size, PROT_READ | PROT_WRITE,
                         MAP_ANONYMOUS | MAP_PRIVATE, -1, 0);
    // mmap reports failure with MAP_FAILED, never with a null pointer.
    assert(mapping != MAP_FAILED);
    raw_stack_ = static_cast<char*>(mapping);
    stack_ = raw_stack_ + page_size;

    const int low_guard = mprotect(raw_stack_, page_size, PROT_NONE);
    const int high_guard = mprotect(stack_ + stack_size_, page_size, PROT_NONE);
    assert(low_guard == 0);
    assert(high_guard == 0);
    (void)low_guard;
    (void)high_guard;

    // Zero-initialise, capture the current context, and only then override the
    // stack and successor: makecontext() expects a context obtained from
    // getcontext() whose uc_stack/uc_link were set afterwards.
    ctx_ = new ucontext_t();
    const int captured = getcontext(ctx_);
    assert(captured == 0);
    (void)captured;
    ctx_->uc_link = main_ctx;
    ctx_->uc_stack.ss_sp = stack_;
    ctx_->uc_stack.ss_size = stack_size_;
    ctx_->uc_stack.ss_flags = 0;
}
// Frees the context object, then unmaps the stack together with its two guard
// pages (the same size expression the constructor used for the mapping).
UContext::~UContext() {
    delete ctx_;
    const auto guard_bytes = getpagesize() * 2;
    munmap(raw_stack_, stack_size_ + guard_bytes);
}
// Stores the task body and its completion callback; work_func invokes them
// when the coroutine runs.
void UContext::set_func(UContextFunc func, CallBack callback) {
    this->func_ = func;
    this->callback_ = callback;
}
// Binds work_func as the entry point of ctx_.
//
// makecontext() forwards its extra arguments as ints, so `this` is passed as
// two 32-bit halves that work_func re-assembles. The pointer is widened to 64
// bits before shifting, which keeps the shift well defined when uintptr_t is
// only 32 bits wide.
void UContext::make() {
    const uint64_t bits = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(this));
    const uint32_t low32 = static_cast<uint32_t>(bits & 0xffffffffu);
    const uint32_t high32 = static_cast<uint32_t>(bits >> 32);
    makecontext(ctx_, reinterpret_cast<void (*)(void)>(work_func), 2, low32, high32);
}
void UContext::resume() {
status_ = RUNNING;
swapcontext(main_ctx_, ctx_);
}
void UContext::yield() {
status_ = SUSPEND;
swapcontext(ctx_, main_ctx_);
}
// Remembers `timer` (may be null) as this coroutine's timer. Only the pointer
// is stored.
void UContext::attach_timer(MyTimer* timer) {
    this->mytimer_ = timer;
}
// Returns the timer set by attach_timer(), or null if none is attached.
MyTimer* UContext::get_timer() {
    MyTimer* const attached = mytimer_;
    return attached;
}
// Adds `fd` to the epoll instance for `events`, with `revents` as the user
// data the scheduler gets back when the descriptor becomes ready.
//
// The descriptor is recorded in register_fd_set only if the kernel accepted
// the registration. The scheduler's round-robin pass skips coroutines whose
// set is non-empty, so recording a failed registration would leave the
// coroutine waiting for an event that can never arrive. A null `revents` is
// rejected for the same reason: the scheduler dereferences the record.
void UContext::register_event(int fd, int events, UContextRevent* revents) {
    if (revents == nullptr) {
        return;
    }
    epoll_event ev{};
    ev.events = static_cast<uint32_t>(events);
    // data is a union: only the pointer member is set (the former data.fd
    // store was overwritten by this assignment anyway).
    ev.data.ptr = revents;
    if (epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, &ev) == 0) {
        register_fd_set.insert(fd);
    }
}
void UContext::remove_event(int fd) {
epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, nullptr);
if (mytimer_ && mytimer_->get_timerfd() == fd) {
this->attach_timer(nullptr);
}
register_fd_set.erase(fd);
}
void UContext::work_func(uint32_t low32, uint32_t high32) {
uintptr_t ptr = (uintptr_t)low32 | ((uintptr_t)high32 << 32);
UContext * uc = (UContext*)ptr;
if (uc->func_) {
uc->func_(uc);
if (uc->callback_) {
uc->callback_();
}
}
uc->status_ = FINISHED;
}
MyUThread.cpp
#include "MyUThread.h"
#include "MyTimer.h"
#include <iostream>
#include <cstring>
#include <unistd.h>
#include <assert.h>
#include <functional>
#include <sys/timerfd.h>
using namespace std;
// Sets up an empty scheduler with max_uthread_count slots, creates the epoll
// instance, and starts the worker thread running thread_work().
//
// The epoll instance is checked before the thread starts. A failed
// epoll_create used to surface only later, as an epoll_wait error inside the
// worker. It is now reported at the call that failed, with the same
// perror + abort policy thread_work uses.
MyUThread::MyUThread(int max_uthread_count, int stack_size) :
    max_uthread_count_(max_uthread_count),
    stack_size_(get_stack_size(stack_size)),
    curr_running_index_(-1),
    curr_max_count_(0),
    idle_count_(max_uthread_count),
    shut_down_(false), thread_(nullptr), epoll_fd_(-1)
{
    memset(&main_ctx_, 0, sizeof(main_ctx_));
    ctx_items_.resize(max_uthread_count_, nullptr);

    epoll_fd_ = epoll_create(100);  // the size argument is only a hint
    if (epoll_fd_ == -1) {
        perror("epoll_create:");
        abort();
    }

    // Started last so the worker never sees a half-initialised object.
    thread_ = new std::thread(thread_work, static_cast<void*>(this));
}
// Stops the worker and releases everything the scheduler owns.
//
// Order matters: the worker is stopped and joined first, because destroying a
// still-joinable std::thread terminates the process and because the worker
// reads the containers freed below. After that, every coroutine still held in
// a slot or in the ready queue is deleted (clearing the containers alone only
// dropped the pointers), and the epoll descriptor is closed.
MyUThread::~MyUThread() {
    shut_down_ = true;
    if (thread_ != nullptr) {
        if (thread_->joinable()) {
            thread_->join();
        }
        delete thread_;
        thread_ = nullptr;
    }

    for (UContext* ctx : ctx_items_) {
        delete ctx;  // empty slots hold nullptr, which delete ignores
    }
    ctx_items_.clear();

    while (!ctx_ready_queue_.empty()) {
        delete ctx_ready_queue_.front();
        ctx_ready_queue_.pop();
    }

    if (epoll_fd_ >= 0) {
        close(epoll_fd_);
        epoll_fd_ = -1;
    }
}
// Creates a coroutine for `func` (with optional completion `callback`) and
// schedules it:
//   1. while fewer than max_uthread_count_ slots have ever been used, it takes
//      the next unused slot;
//   2. otherwise, if a slot has been freed, it takes the first free slot;
//   3. otherwise it waits in the ready queue until remove_uthread() frees one.
//
// The bookkeeping runs under mutex_, the lock remove_uthread() already takes
// around the ready queue, because this function is called from the caller's
// thread while the worker thread removes finished coroutines. If no free slot
// is found even though idle_count_ says one exists, the task is queued instead
// of being written through index -1.
void MyUThread::add_task(UContextFunc func, CallBack callback) {
    UContext* ctx = new UContext(-1, &main_ctx_, stack_size_, epoll_fd_);
    ctx->set_func(func, callback);

    std::lock_guard<std::mutex> lock(mutex_);

    if (curr_max_count_ < max_uthread_count_) {
        const int slot = curr_max_count_;
        ctx->ctx_index_ = slot;
        ctx->make();
        ctx_items_[slot] = ctx;
        ++curr_max_count_;
        --idle_count_;
        return;
    }

    if (idle_count_ > 0) {
        const int used = curr_max_count_;
        for (int slot = 0; slot < used; ++slot) {
            if (ctx_items_[slot] == nullptr) {
                ctx->ctx_index_ = slot;
                ctx->make();
                ctx_items_[slot] = ctx;
                --idle_count_;
                return;
            }
        }
    }

    ctx_ready_queue_.push(ctx);
}
void MyUThread::join() {
while (idle_count_ < max_uthread_count_ || !ctx_ready_queue_.empty()) {
unique_lock<mutex> lock(mutex_join_);
cv_.wait(lock);
}
this->shut_down_ = true;
this->thread_->join();
}
void MyUThread::destory() {
this->shut_down_ = true;
this->thread_->join();
}
void MyUThread::remove_uthread(int index) {
if (index >= 0 && index < curr_max_count_) {
delete ctx_items_[index];
ctx_items_[index] = nullptr;
++idle_count_;
cv_.notify_all();
}
UContext* ctx = nullptr;
{
lock_guard<mutex> lock(mutex_);
if (!ctx_ready_queue_.empty()) {
ctx = ctx_ready_queue_.front();
ctx_ready_queue_.pop();
}
}
if (ctx) {
ctx->ctx_index_ = index;
ctx->make();
ctx_items_[index] = ctx;
--idle_count_;
}
}
// Rounds the requested stack size up to a whole number of pages; anything
// smaller than one page becomes exactly one page.
int MyUThread::get_stack_size(int stack_size) {
    const int page = getpagesize();
    if (stack_size < page) {
        return page;
    }
    const int leftover = stack_size % page;
    if (leftover == 0) {
        return stack_size;  // already page-aligned
    }
    return stack_size + (page - leftover);
}
void MyUThread::thread_work(void* pvoid) {
MyUThread* mut = (MyUThread*)pvoid;
struct epoll_event events[1024];
while (!mut->shut_down_) {
int ret = epoll_wait(mut->epoll_fd_, events, 1024, 4);
if (ret > 0) {//有事件发生
for (int i = 0; i < ret; i++) {
UContextRevent* ctxex = (UContextRevent*)events[i].data.ptr;
auto ctx = ctxex->ctx_;
if (ctx->get_timer() && ctxex->fd_ == ctx->get_timer()->get_timerfd()) {
ctxex->revents_ = UThreadTimerOut;
}else {
ctxex->revents_ = events[i].events;
}
ctx->resume();
if (ctx->status_ == UContext::FINISHED) {
mut->remove_uthread(ctx->ctx_index_);
}
}
}else if (ret == -1) {//异常
if (errno == EINTR) {
continue;
}
else {
perror("epoll_wait:");
abort();
}
}else {//epoll_wait超时
++mut->curr_running_index_;
if (mut->curr_running_index_ >= mut->curr_max_count_) {
mut->curr_running_index_ = -1;
continue;
}
auto ctx = mut->ctx_items_[mut->curr_running_index_];
if (ctx == nullptr || !ctx->register_fd_set.empty()) {
//如果该ctx协程正在等待某个异步事件,则不对他唤醒
continue;
}
ctx->resume();
if (ctx->status_ == UContext::FINISHED) {
mut->remove_uthread(ctx->ctx_index_);
}
}
}
}
MyTimer.cpp
#include "MyTimer.h"
#include "UContext.h"
#include <cstring>
#include <unistd.h>
#include <sys/epoll.h>
#include <sys/timerfd.h>
#include <fcntl.h>
// Creates a non-blocking monotonic timer descriptor and a zeroed itimerspec.
// Aborts if the descriptor cannot be created.
//
// TFD_NONBLOCK already puts the descriptor in non-blocking mode, so the extra
// fcntl(F_GETFL)/fcntl(F_SETFL) round-trip, whose results were never checked,
// is dropped.
MyTimer::MyTimer() : timerfd_(-1), timespec_(nullptr)
{
    timerfd_ = timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK);
    if (timerfd_ == -1) {
        abort();
    }
    // Value-initialisation zeroes every field, which is what the memset did.
    timespec_ = new struct itimerspec();
}
// Releases the itimerspec and closes the timer descriptor.
MyTimer::~MyTimer()
{
    delete timespec_;  // deleting a null pointer is a no-op
    close(timerfd_);
}
void MyTimer::set_once(int seconds, int millseconds) {
this->set_time(seconds, millseconds, 0, 0);
}
void MyTimer::set_cycle(int seconds, int millseconds, int intervalSeconds, int intervalMillSeconds) {
this->set_time(seconds, millseconds, intervalSeconds, intervalMillSeconds);
}
// Returns the timer's file descriptor.
int MyTimer::get_timerfd() {
    const int fd = timerfd_;
    return fd;
}
// Programs the timer: first expiry after seconds + millseconds, then every
// intervalSeconds + intervalMillSeconds (a zero interval means one-shot).
// Does nothing if the timer was not created; aborts if the kernel rejects the
// values.
//
// Milliseconds are converted to nanoseconds with a factor of 1,000,000; the
// previous factor of 1000 produced microseconds, so set_once(0, 500) fired
// after 0.5 ms instead of 0.5 s. Whole seconds contained in the millisecond
// arguments are carried into tv_sec, because tv_nsec must stay below one
// second for timerfd_settime() to accept it.
void MyTimer::set_time(int seconds, int millseconds, int intervalSeconds, int intervalMillSeconds) {
    if (timerfd_ == -1 || timespec_ == nullptr)
        return;

    const long kMillisPerSecond = 1000;
    const long kNanosPerMilli = 1000000;

    timespec_->it_value.tv_sec = seconds + millseconds / kMillisPerSecond;
    timespec_->it_value.tv_nsec = (millseconds % kMillisPerSecond) * kNanosPerMilli;
    timespec_->it_interval.tv_sec = intervalSeconds + intervalMillSeconds / kMillisPerSecond;
    timespec_->it_interval.tv_nsec = (intervalMillSeconds % kMillisPerSecond) * kNanosPerMilli;

    if (-1 == timerfd_settime(timerfd_, 0, timespec_, nullptr)) {
        abort();
    }
}
void MyTimer::stop() {
this->set_time(0, 0, 0, 0);
}
// Returns the time left until the next expiry, in nanoseconds.
//
// Seconds are scaled by 1,000,000,000; the previous constant 10001000 was a
// typo. The sum is computed in 64 bits and saturated to the largest int,
// because the int return type cannot hold more than about 2.1 seconds of
// nanoseconds. If the kernel query fails, 0 is returned (the zero-initialised
// value) instead of reading an uninitialised struct.
int MyTimer::get_time() {
    struct itimerspec remaining = {};
    if (timerfd_gettime(timerfd_, &remaining) == -1) {
        return 0;
    }
    const long long kNanosPerSecond = 1000000000LL;
    // Largest int value, spelled without needing another header.
    const long long kIntMax = static_cast<long long>(~0u >> 1);

    const long long total =
        static_cast<long long>(remaining.it_value.tv_sec) * kNanosPerSecond +
        remaining.it_value.tv_nsec;
    return static_cast<int>(total > kIntMax ? kIntMax : total);
}
这就是我简单编写的全部内容,欢迎指正~