Add a separate buffer for SRTCP packets
[asterisk/asterisk.git] / res / res_srtp.c
index 546f073..a232314 100644 (file)
  */
 
 /*** MODULEINFO
-         <depend>srtp</depend>
+       <depend>srtp</depend>
+       <support_level>core</support_level>
 ***/
 
-/* The SIP channel will automatically use sdescriptions if received in a SDP offer,
-   and res_srtp is loaded. SRTP with sdescriptions key exchange can be activated
-  in outgoing offers by setting _SIPSRTP_CRYPTO=enable in extension.conf before executing Dial
-
-  The dial fails if the callee doesn't support SRTP and sdescriptions.
-
-  exten => 2345,1,Set(_SIPSRTP_CRYPTO=enable)
-  exten => 2345,2,Dial(SIP/1001)
-*/
+/* See https://wiki.asterisk.org/wiki/display/AST/Secure+Calling */
 
 #include "asterisk.h"
 
@@ -53,14 +46,17 @@ ASTERISK_FILE_VERSION(__FILE__, "$Revision$")
 #include "asterisk/module.h"
 #include "asterisk/options.h"
 #include "asterisk/rtp_engine.h"
+#include "asterisk/astobj2.h"
 
 struct ast_srtp {
        struct ast_rtp_instance *rtp;
+       struct ao2_container *policies;
        srtp_t session;
        const struct ast_srtp_cb *cb;
        void *data;
+       int warned;
        unsigned char buf[8192 + AST_FRIENDLY_OFFSET];
-       unsigned int has_stream:1;
+       unsigned char rtcpbuf[8192 + AST_FRIENDLY_OFFSET];
 };
 
 struct ast_srtp_policy {
@@ -73,6 +69,7 @@ static int g_initialized = 0;
 static int ast_srtp_create(struct ast_srtp **srtp, struct ast_rtp_instance *rtp, struct ast_srtp_policy *policy);
 static void ast_srtp_destroy(struct ast_srtp *srtp);
 static int ast_srtp_add_stream(struct ast_srtp *srtp, struct ast_srtp_policy *policy);
+static int ast_srtp_change_source(struct ast_srtp *srtp, unsigned int from_ssrc, unsigned int to_ssrc);
 
 static int ast_srtp_unprotect(struct ast_srtp *srtp, void *buf, int *len, int rtcp);
 static int ast_srtp_protect(struct ast_srtp *srtp, void **buf, int *len, int rtcp);
@@ -90,6 +87,7 @@ static struct ast_srtp_res srtp_res = {
        .create = ast_srtp_create,
        .destroy = ast_srtp_destroy,
        .add_stream = ast_srtp_add_stream,
+       .change_source = ast_srtp_change_source,
        .set_cb = ast_srtp_set_cb,
        .unprotect = ast_srtp_unprotect,
        .protect = ast_srtp_protect,
@@ -144,6 +142,32 @@ static const char *srtp_errstr(int err)
        }
 }
 
+static int policy_hash_fn(const void *obj, const int flags)
+{
+       const struct ast_srtp_policy *policy = obj;
+
+       return policy->sp.ssrc.type == ssrc_specific ? policy->sp.ssrc.value : policy->sp.ssrc.type;
+}
+
+static int policy_cmp_fn(void *obj, void *arg, int flags)
+{
+       const struct ast_srtp_policy *one = obj, *two = arg;
+
+       return one->sp.ssrc.type == two->sp.ssrc.type && one->sp.ssrc.value == two->sp.ssrc.value;
+}
+
+static struct ast_srtp_policy *find_policy(struct ast_srtp *srtp, const srtp_policy_t *policy, int flags)
+{
+       struct ast_srtp_policy tmp = {
+               .sp = {
+                       .ssrc.type = policy->ssrc.type,
+                       .ssrc.value = policy->ssrc.value,
+               },
+       };
+
+       return ao2_t_find(srtp->policies, &tmp, flags, "Looking for policy");
+}
+
 static struct ast_srtp *res_srtp_new(void)
 {
        struct ast_srtp *srtp;
@@ -153,6 +177,13 @@ static struct ast_srtp *res_srtp_new(void)
                return NULL;
        }
 
+       if (!(srtp->policies = ao2_t_container_alloc(5, policy_hash_fn, policy_cmp_fn, "SRTP policy container"))) {
+               ast_free(srtp);
+               return NULL;
+       }
+       
+       srtp->warned = 1;
+
        return srtp;
 }
 
@@ -188,11 +219,21 @@ static void ast_srtp_policy_set_ssrc(struct ast_srtp_policy *policy,
        }
 }
 
+static void policy_destructor(void *obj)
+{
+       struct ast_srtp_policy *policy = obj;
+
+       if (policy->sp.key) {
+               ast_free(policy->sp.key);
+               policy->sp.key = NULL;
+       }
+}
+
 static struct ast_srtp_policy *ast_srtp_policy_alloc()
 {
        struct ast_srtp_policy *tmp;
 
-       if (!(tmp = ast_calloc(1, sizeof(*tmp)))) {
+       if (!(tmp = ao2_t_alloc(sizeof(*tmp), policy_destructor, "Allocating policy"))) {
                ast_log(LOG_ERROR, "Unable to allocate memory for srtp_policy\n");
        }
 
@@ -201,11 +242,7 @@ static struct ast_srtp_policy *ast_srtp_policy_alloc()
 
 static void ast_srtp_policy_destroy(struct ast_srtp_policy *policy)
 {
-       if (policy->sp.key) {
-               ast_free(policy->sp.key);
-               policy->sp.key = NULL;
-       }
-       ast_free(policy);
+       ao2_t_ref(policy, -1, "Destroying policy");
 }
 
 static int policy_set_suite(crypto_policy_t *p, enum ast_srtp_suite suite)
@@ -282,8 +319,11 @@ static int ast_srtp_unprotect(struct ast_srtp *srtp, void *buf, int *len, int rt
 {
        int res = 0;
        int i;
+       int retry = 0;
        struct ast_rtp_instance_stats stats = {0,};
 
+       tryagain:
+
        for (i = 0; i < 2; i++) {
                res = rtcp ? srtp_unprotect_rtcp(srtp->session, buf, len) : srtp_unprotect(srtp->session, buf, len);
                if (res != err_status_no_ctx) {
@@ -302,8 +342,57 @@ static int ast_srtp_unprotect(struct ast_srtp *srtp, void *buf, int *len, int rt
                }
        }
 
+       if (retry == 0  && res == err_status_replay_old) {
+               ast_log(LOG_WARNING, "SRTP unprotect: %s\n", srtp_errstr(res));
+
+               if (srtp->session) {
+                       struct ast_srtp_policy *policy;
+                       struct ao2_iterator it;
+                       int policies_count = 0;
+                       
+                       // dealloc first
+                       ast_log(LOG_WARNING, "SRTP destroy before re-create\n");
+                       srtp_dealloc(srtp->session);
+                       
+                       // get the count
+                       policies_count = ao2_container_count(srtp->policies);
+                       
+                       // get the first to build up
+                       it = ao2_iterator_init(srtp->policies, 0);
+                       policy = ao2_iterator_next(&it);
+
+                       ast_log(LOG_WARNING, "SRTP try to re-create\n");
+                       if (srtp_create(&srtp->session, &policy->sp) == err_status_ok) {
+                               ast_log(LOG_WARNING, "SRTP re-created with first policy\n");
+                               
+                               // unref first element
+                               ao2_t_ref(policy, -1, "Unreffing first policy for re-creating srtp session");
+                               
+                               // if we have more than one policy, add them afterwards 
+                               if (policies_count > 1) {
+                                       ast_log(LOG_WARNING, "Add all the other %d policies\n", policies_count-1);
+                                       while ((policy = ao2_iterator_next(&it))) {
+                                               srtp_add_stream(srtp->session, &policy->sp);
+                                               ao2_t_ref(policy, -1, "Unreffing n-th policy for re-creating srtp session");
+                                       }
+                               }
+                               
+                               retry++;
+                               ao2_iterator_destroy(&it);
+                               goto tryagain;
+                       }
+                       ao2_iterator_destroy(&it);
+               }
+       }
+
        if (res != err_status_ok && res != err_status_replay_fail ) {
-               ast_debug(1, "SRTP unprotect: %s\n", srtp_errstr(res));
+               if ((srtp->warned >= 10) && !((srtp->warned - 10) % 100)) {
+                       ast_log(LOG_WARNING, "SRTP unprotect: %s %d\n", srtp_errstr(res), srtp->warned);
+                       srtp->warned = 11;
+               } else {
+                       srtp->warned++;
+               }
+               errno = EAGAIN;
                return -1;
        }
 
@@ -313,19 +402,22 @@ static int ast_srtp_unprotect(struct ast_srtp *srtp, void *buf, int *len, int rt
 static int ast_srtp_protect(struct ast_srtp *srtp, void **buf, int *len, int rtcp)
 {
        int res;
+       unsigned char *localbuf;
 
        if ((*len + SRTP_MAX_TRAILER_LEN) > sizeof(srtp->buf)) {
                return -1;
        }
+       
+       localbuf = rtcp ? srtp->rtcpbuf : srtp->buf;
 
-       memcpy(srtp->buf, *buf, *len);
+       memcpy(localbuf, *buf, *len);
 
-       if ((res = rtcp ? srtp_protect_rtcp(srtp->session, srtp->buf, len) : srtp_protect(srtp->session, srtp->buf, len)) != err_status_ok && res != err_status_replay_fail) {
-               ast_debug(1, "SRTP protect: %s\n", srtp_errstr(res));
+       if ((res = rtcp ? srtp_protect_rtcp(srtp->session, localbuf, len) : srtp_protect(srtp->session, localbuf, len)) != err_status_ok && res != err_status_replay_fail) {
+               ast_log(LOG_WARNING, "SRTP protect: %s\n", srtp_errstr(res));
                return -1;
        }
 
-       *buf = srtp->buf;
+       *buf = localbuf;
        return *len;
 }
 
@@ -341,9 +433,12 @@ static int ast_srtp_create(struct ast_srtp **srtp, struct ast_rtp_instance *rtp,
                return -1;
        }
 
+       ast_module_ref(ast_module_info->self);
        temp->rtp = rtp;
        *srtp = temp;
 
+       ao2_t_link((*srtp)->policies, policy, "Created initial policy");
+
        return 0;
 }
 
@@ -353,16 +448,53 @@ static void ast_srtp_destroy(struct ast_srtp *srtp)
                srtp_dealloc(srtp->session);
        }
 
+       ao2_t_callback(srtp->policies, OBJ_UNLINK | OBJ_NODATA | OBJ_MULTIPLE, NULL, NULL, "Unallocate policy");
+       ao2_t_ref(srtp->policies, -1, "Destroying container");
+
        ast_free(srtp);
+       ast_module_unref(ast_module_info->self);
 }
 
 static int ast_srtp_add_stream(struct ast_srtp *srtp, struct ast_srtp_policy *policy)
 {
-       if (!srtp->has_stream && srtp_add_stream(srtp->session, &policy->sp) != err_status_ok) {
+       struct ast_srtp_policy *match;
+
+       if ((match = find_policy(srtp, &policy->sp, OBJ_POINTER))) {
+               ast_debug(3, "Policy already exists, not re-adding\n");
+               ao2_t_ref(match, -1, "Unreffing already existing policy");
+               return -1;
+       }
+
+       if (srtp_add_stream(srtp->session, &policy->sp) != err_status_ok) {
                return -1;
        }
 
-       srtp->has_stream = 1;
+       ao2_t_link(srtp->policies, policy, "Added additional stream");
+
+       return 0;
+}
+
+static int ast_srtp_change_source(struct ast_srtp *srtp, unsigned int from_ssrc, unsigned int to_ssrc)
+{
+       struct ast_srtp_policy *match;
+       struct srtp_policy_t sp = {
+               .ssrc.type = ssrc_specific,
+               .ssrc.value = from_ssrc,
+       };
+       err_status_t status;
+
+       /* If we find a mach, return and unlink it from the container so we
+        * can change the SSRC (which is part of the hash) and then have
+        * ast_srtp_add_stream link it back in if all is well */
+       if ((match = find_policy(srtp, &sp, OBJ_POINTER | OBJ_UNLINK))) {
+               match->sp.ssrc.value = to_ssrc;
+               if (ast_srtp_add_stream(srtp, match)) {
+                       ast_log(LOG_WARNING, "Couldn't add stream\n");
+               } else if ((status = srtp_remove_stream(srtp->session, from_ssrc))) {
+                       ast_debug(3, "Couldn't remove stream (%d)\n", status);
+               }
+               ao2_t_ref(match, -1, "Unreffing found policy in change_source");
+       }
 
        return 0;
 }