• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

emqx / esockd / 315

15 Sep 2023 01:50PM UTC coverage: 70.531% (+0.2%) from 70.291%
315

push

github

zmstone
chore: load esockd_connection_sup in appup

730 of 1035 relevant lines covered (70.53%)

176.59 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

60.94
/src/esockd_connection_sup.erl
1
%%--------------------------------------------------------------------
2
%% Copyright (c) 2020 EMQ Technologies Co., Ltd. All Rights Reserved.
3
%%
4
%% Licensed under the Apache License, Version 2.0 (the "License");
5
%% you may not use this file except in compliance with the License.
6
%% You may obtain a copy of the License at
7
%%
8
%%     http://www.apache.org/licenses/LICENSE-2.0
9
%%
10
%% Unless required by applicable law or agreed to in writing, software
11
%% distributed under the License is distributed on an "AS IS" BASIS,
12
%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
%% See the License for the specific language governing permissions and
14
%% limitations under the License.
15
%%--------------------------------------------------------------------
16

17
-module(esockd_connection_sup).
18

19
-behaviour(gen_server).
20

21
-import(proplists, [get_value/3]).
22

23
-export([start_link/2, stop/1]).
24

25
-export([ start_connection/3
26
        , count_connections/1
27
        ]).
28

29
-export([ get_max_connections/1
30
        , set_max_connections/2
31
        ]).
32

33
-export([get_shutdown_count/1]).
34

35
%% Allow, Deny
36
-export([ access_rules/1
37
        , allow/2
38
        , deny/2
39
        ]).
40

41
%% gen_server callbacks
42
-export([ init/1
43
        , handle_call/3
44
        , handle_cast/2
45
        , handle_info/2
46
        , terminate/2
47
        , code_change/3
48
        ]).
49

50
-type(shutdown() :: brutal_kill | infinity | pos_integer()).
51

52
-record(state, {
53
          curr_connections :: map(),
54
          max_connections :: pos_integer(),
55
          access_rules :: list(),
56
          shutdown :: shutdown(),
57
          mfargs :: esockd:mfargs()
58
         }).
59

60
-define(DEFAULT_MAX_CONNS, 1024).
61
-define(TRANSPORT, esockd_transport).
62
-define(ERROR_MSG(Format, Args),
63
        error_logger:error_msg("[~s] " ++ Format, [?MODULE | Args])).
64

65
%% @doc Start the connection supervisor.
%% Starts a linked gen_server running this module; `Opts' and `MFA' are
%% passed to init/1 as the two-element list `[Opts, MFA]'.
-spec(start_link([esockd:option()], esockd:mfargs())
      -> {ok, pid()} | ignore | {error, term()}).
start_link(Opts, MFA) ->
    InitArgs = [Opts, MFA],
    gen_server:start_link(?MODULE, InitArgs, []).
171✔
70

71
%% @doc Stop the connection supervisor by delegating to gen_server:stop/1.
-spec(stop(pid()) -> ok).
stop(Pid) ->
    gen_server:stop(Pid).
18✔
73

74
%%--------------------------------------------------------------------
75
%% API
76
%%--------------------------------------------------------------------
77

78
%% @doc Start a connection process for an accepted socket.
%% Asks the supervisor process to spawn the connection. On `{ok, ConnPid}'
%% the socket's controlling process is switched to `ConnPid' and the new
%% process is told the socket is ready (together with `UpgradeFuns').
%% `ignore' and `{error, Reason}' replies are returned unchanged.
start_connection(Sup, Sock, UpgradeFuns) ->
    Reply = call(Sup, {start_connection, Sock}),
    case Reply of
        {ok, ConnPid} ->
            %% Transfer controlling from acceptor to connection; the results
            %% of both transport calls are intentionally discarded.
            _ = ?TRANSPORT:controlling_process(Sock, ConnPid),
            _ = ?TRANSPORT:ready(ConnPid, Sock, UpgradeFuns),
            Reply;
        ignore ->
            Reply;
        {error, _Reason} ->
            Reply
    end.
90

91
%% @doc Start the connection process from the configured callback.
%% Accepted callback shapes:
%%   Module            - calls Module:start_link(Transport, Sock)
%%   {Module, Fun}     - calls Module:Fun(Transport, Sock)
%%   {Module, Fun, As} - calls Module:Fun(Transport, Sock, As...)
%% where Transport is the ?TRANSPORT module.
-spec(start_connection_proc(esockd:mfargs(), esockd_transport:socket())
      -> {ok, pid()} | ignore | {error, term()}).
start_connection_proc(Mod, Sock) when is_atom(Mod) ->
    erlang:apply(Mod, start_link, [?TRANSPORT, Sock]);
start_connection_proc({Mod, Fun}, Sock) when is_atom(Mod), is_atom(Fun) ->
    erlang:apply(Mod, Fun, [?TRANSPORT, Sock]);
start_connection_proc({Mod, Fun, ExtraArgs}, Sock)
  when is_atom(Mod), is_atom(Fun), is_list(ExtraArgs) ->
    CallArgs = [?TRANSPORT, Sock | ExtraArgs],
    erlang:apply(Mod, Fun, CallArgs).
132✔
100

101
%% @doc Return the number of connections currently tracked by the supervisor.
-spec(count_connections(pid()) -> integer()).
count_connections(Sup) ->
    Req = count_connections,
    call(Sup, Req).
12✔
104

105
%% @doc Return the configured connection limit of the supervisor.
%% Only accepts a pid (enforced by the guard).
-spec(get_max_connections(pid()) -> integer()).
get_max_connections(Sup) when is_pid(Sup) ->
    Req = get_max_connections,
    call(Sup, Req).
21✔
108

109
%% @doc Replace the connection limit of the supervisor with `MaxConns'.
%% Only accepts a pid (enforced by the guard).
-spec(set_max_connections(pid(), integer()) -> ok).
set_max_connections(Sup, MaxConns) when is_pid(Sup) ->
    Req = {set_max_connections, MaxConns},
    call(Sup, Req).
9✔
112

113
%% @doc Return the per-reason shutdown counters kept by the supervisor
%% as a list of `{Reason, Count}' pairs.
-spec(get_shutdown_count(pid()) -> [{atom(), integer()}]).
get_shutdown_count(Sup) ->
    Req = get_shutdown_count,
    call(Sup, Req).
12✔
116

117
%% @doc Return the supervisor's access rules in raw (uncompiled) form.
access_rules(Sup) ->
    Req = access_rules,
    call(Sup, Req).
33✔
119

120
%% @doc Add an `allow' access rule for `CIDR' to the supervisor.
allow(Sup, CIDR) ->
    Rule = {allow, CIDR},
    call(Sup, {add_rule, Rule}).
12✔
122

123
%% @doc Add a `deny' access rule for `CIDR' to the supervisor.
deny(Sup, CIDR) ->
    Rule = {deny, CIDR},
    call(Sup, {add_rule, Rule}).
12✔
125

126
%% Synchronous request to the supervisor process; never times out.
call(Sup, Req) ->
    Timeout = infinity,
    gen_server:call(Sup, Req, Timeout).
243✔
128

129
%%--------------------------------------------------------------------
130
%% gen_server callbacks
131
%%--------------------------------------------------------------------
132

133
%% gen_server init callback.
%% Traps exits so that connection terminations arrive as 'EXIT' messages,
%% reads `shutdown', `max_connections' and `access_rules' from `Opts'
%% (defaults: brutal_kill, ?DEFAULT_MAX_CONNS, [{allow, all}]) and compiles
%% the access rules before building the initial state.
init([Opts, MFA]) ->
    process_flag(trap_exit, true),
    Option = fun(Key, Default) -> proplists:get_value(Key, Opts, Default) end,
    Shutdown = Option(shutdown, brutal_kill),
    MaxConns = Option(max_connections, ?DEFAULT_MAX_CONNS),
    RawRules = Option(access_rules, [{allow, all}]),
    Compiled = [esockd_access:compile(Raw) || Raw <- RawRules],
    InitState = #state{curr_connections = #{},
                       max_connections  = MaxConns,
                       access_rules     = Compiled,
                       shutdown         = Shutdown,
                       mfargs           = MFA},
    {ok, InitState}.
144

145
%% gen_server call handler.
%%
%% Requests:
%%   {start_connection, Sock}   - spawn a connection process for Sock;
%%                                replies {ok, Pid} | ignore | {error, Reason}
%%   count_connections          - number of tracked connections
%%   get_max_connections        - current connection limit
%%   {set_max_connections, N}   - replace the connection limit
%%   get_shutdown_count         - [{Reason, Count}] from the process dictionary
%%   access_rules               - access rules in raw form
%%   {add_rule, RawRule}        - compile and prepend an access rule
%%   which_children             - supervisor-style child list
%% Any other request is logged and answered with `ignore'.

%% Connection limit reached: refuse without touching the socket.
handle_call({start_connection, _Sock}, _From,
            State = #state{curr_connections = Conns, max_connections = MaxConns})
    when map_size(Conns) >= MaxConns ->
    {reply, {error, maxlimit}, State};

handle_call({start_connection, Sock}, _From, State = #state{access_rules = Rules}) ->
    case esockd_transport:peername(Sock) of
        {ok, {Addr, _Port}} ->
            case allowed(Addr, Rules) of
                true  -> spawn_connection(Sock, State);
                false -> {reply, {error, forbidden}, State}
            end;
        {error, Reason} ->
            {reply, {error, Reason}, State}
    end;

handle_call(count_connections, _From, State = #state{curr_connections = Conns}) ->
    {reply, maps:size(Conns), State};

handle_call(get_max_connections, _From, State = #state{max_connections = MaxConns}) ->
    {reply, MaxConns, State};

handle_call({set_max_connections, MaxConns}, _From, State) ->
    {reply, ok, State#state{max_connections = MaxConns}};

handle_call(get_shutdown_count, _From, State) ->
    %% Counters live in the process dictionary under {shutdown_count, Reason}.
    Counts = [{Reason, Count} || {{shutdown_count, Reason}, Count} <- get()],
    {reply, Counts, State};

handle_call(access_rules, _From, State = #state{access_rules = Rules}) ->
    {reply, [raw(Rule) || Rule <- Rules], State};

handle_call({add_rule, RawRule}, _From, State = #state{access_rules = Rules}) ->
    try esockd_access:compile(RawRule) of
        Rule ->
            case lists:member(Rule, Rules) of
                true ->
                    {reply, {error, already_exists}, State};
                false ->
                    {reply, ok, State#state{access_rules = [Rule | Rules]}}
            end
    catch
        error:Reason ->
            error_logger:error_msg("Bad access rule: ~p, compile error: ~p", [RawRule, Reason]),
            {reply, {error, bad_access_rule}, State}
    end;

%% mimic the supervisor's which_children reply
handle_call(which_children, _From, State = #state{curr_connections = Conns, mfargs = MFA}) ->
    Mod = get_module(MFA),
    {reply, [{undefined, Pid, worker, [Mod]}
              || Pid <- maps:keys(Conns), erlang:is_process_alive(Pid)], State};

handle_call(Req, _From, State) ->
    ?ERROR_MSG("Unexpected call: ~p", [Req]),
    {reply, ignore, State}.

%% Spawn the connection process for an already admitted socket and, on
%% success, start tracking its pid. Any exception raised by the callback is
%% converted into an {error, {Reason, Stacktrace}} reply so that a faulty
%% callback does not take the supervisor down.
spawn_connection(Sock, State = #state{curr_connections = Conns, mfargs = MFA}) ->
    try start_connection_proc(MFA, Sock) of
        {ok, Pid} when is_pid(Pid) ->
            NState = State#state{curr_connections = maps:put(Pid, true, Conns)},
            {reply, {ok, Pid}, NState};
        ignore ->
            {reply, ignore, State};
        {error, Reason} ->
            {reply, {error, Reason}, State}
    catch
        _Class:Reason:ST ->
            {reply, {error, {Reason, ST}}, State}
    end.
3✔
215

216
%% gen_server cast handler: no casts are expected, so every cast is
%% logged and otherwise ignored.
handle_cast(Unexpected, State) ->
    ?ERROR_MSG("Unexpected cast: ~p", [Unexpected]),
    {noreply, State}.
3✔
219

220
%% gen_server info handler.
%% 'EXIT' from a tracked connection: record the exit via
%% connection_crashed/3 and stop tracking the pid. 'EXIT' from an unknown
%% pid and any other message are logged and ignored.
handle_info({'EXIT', Pid, Reason}, State = #state{curr_connections = Conns}) ->
    case maps:take(Pid, Conns) of
        error ->
            ?ERROR_MSG("Unexpected 'EXIT': ~p, reason: ~p", [Pid, Reason]),
            {noreply, State};
        {true, Remaining} ->
            connection_crashed(Pid, Reason, State),
            {noreply, State#state{curr_connections = Remaining}}
    end;

handle_info(Unexpected, State) ->
    ?ERROR_MSG("Unexpected info: ~p", [Unexpected]),
    {noreply, State}.
3✔
233

234
%% gen_server terminate callback: shut down every tracked connection
%% process, regardless of why the supervisor is stopping.
terminate(_Why, State) ->
    terminate_children(State).
171✔
236

237
%% gen_server code_change callback: the state is returned unchanged.
code_change(_FromVsn, State, _Extra) ->
    {ok, State}.
×
239

240
%%--------------------------------------------------------------------
241
%% Internal functions
242
%%--------------------------------------------------------------------
243

244
%% Decide whether `Addr' may connect under the compiled `Rules'.
%% Only an explicit deny match rejects the address; an allow match or no
%% match at all admits it.
allowed(Addr, Rules) ->
    Verdict = esockd_access:match(Addr, Rules),
    case Verdict of
        {matched, deny}  -> false;
        {matched, allow} -> true;
        nomatch          -> true
    end.
250

251
%% Convert a compiled access rule back to its raw form.
%% An allow/deny rule holding a 3-tuple CIDR gets the CIDR rendered with
%% esockd_cidr:to_string/1; any other rule is returned as is.
raw({Action, {_Start, _End, _Len} = CIDR}) when Action =:= allow; Action =:= deny ->
    {Action, esockd_cidr:to_string(CIDR)};
raw(Other) ->
    Other.
×
257

258
%% Account for the exit of a tracked connection process.
%%   normal | shutdown | killed       - nothing to do
%%   any other atom                   - counted under that atom
%%   {shutdown, Reason}               - see shutdown_exit/3
%%   anything else                    - logged as an unexpected crash
connection_crashed(_Pid, Reason, _State)
  when Reason =:= normal; Reason =:= shutdown; Reason =:= killed ->
    ok;
connection_crashed(_Pid, Reason, _State) when is_atom(Reason) ->
    count_shutdown(Reason);
connection_crashed(Pid, {shutdown, Reason}, State) ->
    shutdown_exit(Pid, Reason, State);
connection_crashed(Pid, Reason, State) ->
    %% unexpected crash, probably deserves a fix
    log(error, connection_crashed, Reason, Pid, State).

%% Handle the payload of a `{shutdown, Reason}' exit.
%%   atom                             - counted under that atom
%%   {ssl_error, Detail}              - counted as ssl_error, logged at info
%%   map with atom `shutdown_count'   - counted under that key, and the rest
%%                                      of the map is logged at info
%%   anything else                    - logged at error level, not counted
shutdown_exit(_Pid, Reason, _State) when is_atom(Reason) ->
    count_shutdown(Reason);
shutdown_exit(Pid, {ssl_error, Detail}, State) ->
    count_shutdown(ssl_error),
    log(info, ssl_error, Detail, Pid, State);
shutdown_exit(Pid, #{shutdown_count := Key} = Reason, State) when is_atom(Key) ->
    count_shutdown(Key),
    log(info, Key, maps:without([shutdown_count], Reason), Pid, State);
shutdown_exit(Pid, Reason, State) ->
    %% unidentified shutdown, cannot keep a counter of it;
    %% ideally a 'shutdown_count' field should be added to the reason.
    log(error, connection_shutdown, Reason, Pid, State).
3✔
281

282
%% Increment the counter for `Reason', stored in the process dictionary
%% under {shutdown_count, Reason}. A missing counter starts at 1.
%% Returns whatever put/2 returns (the previous value or `undefined').
count_shutdown(Reason) ->
    Key = {shutdown_count, Reason},
    Next = case get(Key) of
               undefined -> 1;
               Current   -> Current + 1
           end,
    put(Key, Next).
54✔
285

286
%% Terminate every tracked connection process according to `shutdown':
%%   brutal_kill - send `kill' and wait for all children
%%   infinity    - send `shutdown' and wait without a time limit
%%   integer     - send `shutdown' and arm a timer of that many
%%                 milliseconds, delivered as {timeout, TRef, kill}
%% Exit reasons collected while waiting are reported as
%% connection_shutdown_error.
terminate_children(State = #state{curr_connections = Conns, shutdown = Shutdown}) ->
    {Pids, EStack0} = monitor_children(Conns),
    Sz = sets:size(Pids),
    {Signal, Timeout} =
        case Shutdown of
            brutal_kill                -> {kill, undefined};
            infinity                   -> {shutdown, undefined};
            Time when is_integer(Time) -> {shutdown, Time}
        end,
    sets:fold(fun(P, _) -> exit(P, Signal) end, ok, Pids),
    TRef = case Timeout of
               undefined -> undefined;
               _         -> erlang:start_timer(Timeout, self(), kill)
           end,
    EStack = wait_children(Shutdown, Pids, Sz, TRef, EStack0),
    %% Unroll stacked errors and report them; each dict value is the list
    %% of pids that terminated with that reason.
    Report = fun(Reason, FailedPids, _Acc) ->
                 log(error, connection_shutdown_error, Reason, FailedPids, State)
             end,
    dict:fold(Report, ok, EStack).
305

306
%% Switch every tracked connection from link to monitor.
%% Returns {Pids, EStack}: the set of pids that must still be waited for,
%% and a dict from exit reason to the pids that had already died with that
%% reason. Children that had already exited with `normal' are dropped.
monitor_children(Conns) ->
    Collect =
        fun(Pid, {Alive, Errors}) ->
            case monitor_child(Pid) of
                ok ->
                    {sets:add_element(Pid, Alive), Errors};
                {error, normal} ->
                    {Alive, Errors};
                {error, Reason} ->
                    {Alive, dict:append(Reason, Pid, Errors)}
            end
        end,
    lists:foldl(Collect, {sets:new(), dict:new()}, maps:keys(Conns)).
317

318
%% Helper that switches supervision of `Pid' from a link to a monitor.
%% Returns `ok' when no 'EXIT' from the child is pending, or
%% `{error, Reason}' when the child had already exited with `Reason'.
monitor_child(Pid) ->
    %% Monitor first: if the child dies before the monitor is set up the
    %% 'DOWN' message carries `noproc', but the real reason is still
    %% available from the 'EXIT' message (unless a naughty child has
    %% already unlinked itself).
    erlang:monitor(process, Pid),
    unlink(Pid),

    receive
        %% The child died before the unlink: drain both its 'EXIT' and
        %% its 'DOWN' message from the mailbox.
        {'EXIT', Pid, ExitReason} ->
            receive
                {'DOWN', _MRef, process, Pid, _DownReason} ->
                    {error, ExitReason}
            end
    after 0 ->
            %% No 'EXIT' pending. If a naughty child unlinked and died
            %% before the monitor, a 'DOWN' with reason `noproc' will
            %% follow; if the child dies after the unlink, a 'DOWN' with
            %% the correct reason will follow. Both are left in the
            %% mailbox for the code that waits for the children.
            ok
    end.
344

345
%% Wait until every monitored child in `Pids' has delivered its 'DOWN'
%% message.
%%
%%   Shutdown - brutal_kill | infinity | timeout in milliseconds
%%   Pids     - set of child pids still alive
%%   Sz       - number of 'DOWN' messages still expected
%%   TRef     - timer reference for the kill timeout, or `undefined'
%%   EStack   - dict mapping unexpected exit reasons to lists of pids
%%
%% Returns the final EStack.
%%
%% Adapted from OTP's supervisor.erl.
wait_children(_Shutdown, _Pids, 0, undefined, EStack) ->
    EStack;
wait_children(_Shutdown, _Pids, 0, TRef, EStack) ->
    %% If the timer has expired before its cancellation, we must empty the
    %% mail-box of the 'timeout'-message.
    _ = erlang:cancel_timer(TRef),
    receive
        {timeout, TRef, kill} ->
            EStack
    after 0 ->
            EStack
    end;

wait_children(brutal_kill, Pids, Sz, TRef, EStack) ->
    receive
        %% `killed' is the expected reason after exit(Pid, kill).
        {'DOWN', _MRef, process, Pid, killed} ->
            wait_children(brutal_kill, sets:del_element(Pid, Pids), Sz-1, TRef, EStack);

        {'DOWN', _MRef, process, Pid, Reason} ->
            wait_children(brutal_kill, sets:del_element(Pid, Pids),
                          Sz-1, TRef, dict:append(Reason, Pid, EStack))
    end;

wait_children(Shutdown, Pids, Sz, TRef, EStack) ->
    receive
        {'DOWN', _MRef, process, Pid, shutdown} ->
            wait_children(Shutdown, sets:del_element(Pid, Pids), Sz-1, TRef, EStack);
        {'DOWN', _MRef, process, Pid, normal} ->
            wait_children(Shutdown, sets:del_element(Pid, Pids), Sz-1, TRef, EStack);
        {'DOWN', _MRef, process, Pid, Reason} ->
            wait_children(Shutdown, sets:del_element(Pid, Pids), Sz-1,
                          TRef, dict:append(Reason, Pid, EStack));
        {timeout, TRef, kill} ->
            sets:fold(fun(P, _) -> exit(P, kill) end, ok, Pids),
            %% The timeout removes no child from the set, so the number of
            %% outstanding 'DOWN' messages is unchanged. Decrementing Sz
            %% here would end the wait one child too early.
            wait_children(Shutdown, Pids, Sz, undefined, EStack)
    end.
382

383
%% Emit a supervisor-style report through error_logger.
%%   Level  - info | error, selects info_report/2 or error_report/2
%%   Error  - value reported as `errorContext'
%%   Reason - value reported as `reason'
%%   Pid    - offender pid (also used in the `supervisor' entry)
log(Level, Error, Reason, Pid, #state{mfargs = MFA}) ->
    Offender = [{pid, Pid},
                {name, connection},
                {mfargs, MFA}],
    Report = [{supervisor, {?MODULE, Pid}},
              {errorContext, Error},
              {reason, Reason},
              {offender, Offender}],
    case Level of
        error ->
            error_logger:error_report(supervisor_report, Report);
        info ->
            error_logger:info_report(supervisor_report, Report)
    end.
396

397
%% Extract the callback module from the configured MFA: the first element
%% of a 2- or 3-tuple, otherwise the term itself.
get_module(MFA) when is_tuple(MFA), tuple_size(MFA) =:= 3 ->
    element(1, MFA);
get_module(MF) when is_tuple(MF), tuple_size(MF) =:= 2 ->
    element(1, MF);
get_module(Mod) ->
    Mod.
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc