Task-自定义任务

前言

本篇会研究如何在c#里实现自定义的Task, 一般来说c#默认的Task基本能实现绝大部分的异步需求, 为什么需要自定义Task呢?

  • 性能优化: 如果你的任务是高频率的小任务,使用 ValueTask 或自定义的轻量级任务类型可以减少内存分配和 GC 压力。
  • 灵活性:你可以根据需要设计任务的行为,比如自定义错误处理、超时机制、取消支持等。
  • 并发控制:有时你需要控制任务的执行方式,比如限制最大并发数,或者为任务设置特定的优先级,这时自定义类型可能会更适合。
  • 组合和调度:在复杂的异步操作中,可能需要将多个任务组合在一起执行并管理任务的依赖关系。通过自定义任务类型,你可以更好地控制任务的执行顺序和逻辑。

1. Task Type1

任务类型的组成:

  1. 一个classStruct并带有System.Runtime.CompilerServices.AsyncMethodBuilderAttribute属性。
1
2
3
4
5
6
[AsyncMethodBuilder(typeof(MyTaskMethodBuilder<>))]
class MyTask<T>
{
    // 为了支持`Await`, `task type`必须有一个可访问的`GetAwaiter`方法返回`awaiter type`。
    public Awaiter<T> GetAwaiter();
}
  1. Awaiter type: 一个可等待的类型必须实现INotifyCompletion并满足如下条件
1
2
3
4
5
6
7
8
9
class Awaiter<T> : INotifyCompletion
{
    // 判断异步操作是否已经完成
    public bool IsCompleted { get; }
    // 获取异步执行结果
    public T GetResult();
    // 注册一个回调函数,当异步操作完成时,调用这个回调。
    public void OnCompleted(Action completion);
}
  1. Builder Type: 一个与Task对应的classStruct, builder type最多只能有一个类型参数且不能嵌套在泛型类中。有以下public方法, 对于非泛型的builer type, SetResult方法没有参数:
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class MyTaskMethodBuilder<T>
{
    public static MyTaskMethodBuilder<T> Create();

    public void Start<TStateMachine>(ref TStateMachine stateMachine)
        where TStateMachine : IAsyncStateMachine;

    public void SetStateMachine(IAsyncStateMachine stateMachine);
    public void SetException(Exception exception);
    public void SetResult(T result);

    public void AwaitOnCompleted<TAwaiter, TStateMachine>(
        ref TAwaiter awaiter, ref TStateMachine stateMachine)
        where TAwaiter : INotifyCompletion
        where TStateMachine : IAsyncStateMachine;

    public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(
        ref TAwaiter awaiter, ref TStateMachine stateMachine)
        where TAwaiter : ICriticalNotifyCompletion
        where TStateMachine : IAsyncStateMachine;

    public MyTask<T> Task { get; }
}

2. 自定义任务实现

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
static async Task Main(string[] args)
{
    var result = await FooAsync();
    Console.WriteLine($"result={result}");
}

public static async MyTask<int> FooAsync()
{
    Console.WriteLine("mytask exec 1");
    await Task.Delay(1000);
    Console.WriteLine("mytask exec 2");
    await Task.Delay(1000);
    Console.WriteLine("mytask exec 3");
    return 999;
}

[AsyncMethodBuilder(typeof(MyTaskMethodBuilder<>))]
public class MyTask<T>
{
    private MyAwaiter<T> _awaiter;
    public T Value;
    public bool IsCompleted { get; private set; }

    public MyAwaiter<T> GetAwaiter()
    {
        Console.WriteLine("MyTask GetAwaiter");
        return _awaiter ??= new MyAwaiter<T>()
        {
            Task = this,
        };
    }

    public void SetResult(T result)
    {
        Console.WriteLine($"MyTask SetResult result={result}");
        Value = result;
        IsCompleted = true;
        _awaiter?.RunCallback();
        _awaiter = null;
    }
}

public class MyAwaiter<T> : INotifyCompletion
{
    public bool IsCompleted => Task.IsCompleted;
    public MyTask<T> Task { get; set; }

    private Action _continuation;

    public T GetResult()
    {
        Console.WriteLine("MyAwaiter GetResult");
        return Task.Value;
    }

    public void OnCompleted(Action completion)
    {
        Console.WriteLine("MyAwaiter OnCompleted");
        _continuation = completion;
    }

    public void RunCallback()
    {
        Console.WriteLine("MyAwaiter RunCallback");
        _continuation?.Invoke();
    }
}

public class MyTaskMethodBuilder<T>
{
    public static MyTaskMethodBuilder<T> Create()
    {
        Console.WriteLine($"MyTaskMethodBuilder Create type={typeof(T).Name}");
        return new MyTaskMethodBuilder<T>()
        {
            Task = new MyTask<T>()
        };
    }

    public void Start<TStateMachine>(ref TStateMachine stateMachine)
        where TStateMachine : IAsyncStateMachine
    {
        Console.WriteLine($"MyTaskMethodBuilder Start");
        stateMachine.MoveNext();
    }

    public void SetStateMachine(IAsyncStateMachine stateMachine)
    {
        Console.WriteLine($"MyTaskMethodBuilder SetStateMachine stateMachine={stateMachine}");
    }

    public void SetException(Exception exception)
    {
        Console.WriteLine($"MyTaskMethodBuilder SetException exception={exception}");
    }

    public void SetResult(T result)
    {
        Console.WriteLine($"MyTaskMethodBuilder SetResult result={result}");
        Task.SetResult(result);
    }

    public void AwaitOnCompleted<TAwaiter, TStateMachine>(
        ref TAwaiter awaiter, ref TStateMachine stateMachine)
        where TAwaiter : INotifyCompletion
        where TStateMachine : IAsyncStateMachine
    {
        Console.WriteLine($"MyTaskMethodBuilder AwaitOnCompleted");
        awaiter.OnCompleted(stateMachine.MoveNext);
    }

    public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(
        ref TAwaiter awaiter, ref TStateMachine stateMachine)
        where TAwaiter : ICriticalNotifyCompletion
        where TStateMachine : IAsyncStateMachine
    {
        Console.WriteLine($"MyTaskMethodBuilder AwaitUnsafeOnCompleted");
        Console.WriteLine($"{awaiter.GetType().Name}");
        awaiter.OnCompleted(stateMachine.MoveNext);
    }

    public MyTask<T> Task { get; private set; }
}

上面的代码展示了一个简单的自定义任务实现, 运行代码可以看到如下打印:

  1. MyTaskMethodBuilder Create type=Int32
  2. MyTaskMethodBuilder Start
  3. mytask exec 1
  4. MyTaskMethodBuilder AwaitUnsafeOnCompleted
  5. TaskAwaiter
  6. MyTask GetAwaiter
  7. MyAwaiter OnCompleted
  8. mytask exec 2
  9. MyTaskMethodBuilder AwaitUnsafeOnCompleted
  10. TaskAwaiter
  11. mytask exec 3
  12. MyTaskMethodBuilder SetResult result=999
  13. MyTask SetResult result=999
  14. MyAwaiter RunCallback
  15. MyAwaiter GetResult
  16. result=999

逐步分析打印:

  1. Main方法状态机执行var result = await FooAsync(); -> 创建一个MyTaskMethodBuilder类型。
  2. 执行FooAsync对应的状态机的Start方法。
  3. 执行FooAsyncConsole.WriteLine("mytask exec 1");
  4. 执行FooAsync的第一个await Task.Delay(1000);, 判断该任务未完成, 执行MyTaskMethodBuilder里的AwaitUnsafeOnCompleted方法。
  5. AwaitUnsafeOnCompleted里打印了该方法传入的awaiter类型, 为TaskAwaiter, 注意到这是c#默认的Task对应的Awaiter, await后面等待的是什么任务类型, 状态机的后续驱动 就由该任务类型来调用执行。
  6. 这里在获取自定义的MyAwaiter, 因为上一步FooAsync状态机执行完成了, 进入了异步的等待状态, 所以这一步其实是回到了Main方法对应的状态机, 因为await后是我们的自定义任务类型, 所以获取了自定义的MyAwaiter
  7. 还是Main方法的状态机执行, 判断到MyAwaiterIsCompleted没有完成, 因此Main方法的状态机调用了MyAwaiterOnCompleted方法, 并传入Main状态机的MoveNext方法作为FooAsync方法执行完成的回调。
  8. FooAsync的第一个await执行完成, 由默认的Task执行FooAsyncMoveNext方法。
  9. 同4
  10. 同5
  11. 同8
  12. FooAsync执行完成, 对应最后的return 999;, 状态机调用MyTaskMethodBuilderSetResult方法。
  13. MyTask执行SetResult方法, 这里会将IsCompleted属性标记为完成状态, 同时执行’continue action’回调。
  14. 上一步调用了MyAwaiterRunCallback方法。方法里执行第7步传入的回调。
  15. 执行Main状态机的MoveNext方法, 获取异步结果并赋值给result参数。
  16. 执行Console.WriteLine($"result={result}");
0%