Merge "Fix ast_(v)asprintf() malloc failure usage conditions."
[asterisk/asterisk.git] / res / res_pjsip_transport_websocket.c
index 92e018d..22ec195 100644 (file)
@@ -38,8 +38,8 @@
 #include "asterisk/res_pjsip_session.h"
 #include "asterisk/taskprocessor.h"
 
-static int transport_type_ws;
 static int transport_type_wss;
+static int transport_type_wss_ipv6;
 
 /*!
  * \brief Wrapper for pjsip_transport, for storing the WebSocket session
@@ -63,8 +63,9 @@ static pj_status_t ws_send_msg(pjsip_transport *transport,
                             pjsip_transport_callback callback)
 {
        struct ws_transport *wstransport = (struct ws_transport *)transport;
+       uint64_t len = tdata->buf.cur - tdata->buf.start;
 
-       if (ast_websocket_write(wstransport->ws_session, AST_WEBSOCKET_OPCODE_TEXT, tdata->buf.start, (int)(tdata->buf.cur - tdata->buf.start))) {
+       if (ast_websocket_write(wstransport->ws_session, AST_WEBSOCKET_OPCODE_TEXT, tdata->buf.start, len)) {
                return PJ_EUNKNOWN;
        }
 
@@ -79,6 +80,25 @@ static pj_status_t ws_send_msg(pjsip_transport *transport,
 static pj_status_t ws_destroy(pjsip_transport *transport)
 {
        struct ws_transport *wstransport = (struct ws_transport *)transport;
+       int fd = ast_websocket_fd(wstransport->ws_session);
+
+       if (fd > 0) {
+               ast_websocket_close(wstransport->ws_session, 1000);
+               shutdown(fd, SHUT_RDWR);
+       }
+
+       ao2_ref(wstransport, -1);
+
+       return PJ_SUCCESS;
+}
+
+static void transport_dtor(void *arg)
+{
+       struct ws_transport *wstransport = arg;
+
+       if (wstransport->ws_session) {
+               ast_websocket_unref(wstransport->ws_session);
+       }
 
        if (wstransport->transport.ref_cnt) {
                pj_atomic_destroy(wstransport->transport.ref_cnt);
@@ -88,16 +108,28 @@ static pj_status_t ws_destroy(pjsip_transport *transport)
                pj_lock_destroy(wstransport->transport.lock);
        }
 
-       pjsip_endpt_release_pool(wstransport->transport.endpt, wstransport->transport.pool);
+       if (wstransport->transport.endpt && wstransport->transport.pool) {
+               pjsip_endpt_release_pool(wstransport->transport.endpt, wstransport->transport.pool);
+       }
 
-       return PJ_SUCCESS;
+       if (wstransport->rdata.tp_info.pool) {
+               pjsip_endpt_release_pool(wstransport->transport.endpt, wstransport->rdata.tp_info.pool);
+       }
 }
 
 static int transport_shutdown(void *data)
 {
-       pjsip_transport *transport = data;
+       struct ws_transport *wstransport = data;
+
+       if (!wstransport->transport.is_shutdown && !wstransport->transport.is_destroying) {
+               pjsip_transport_shutdown(&wstransport->transport);
+       }
+
+       /* Note that the destructor calls PJSIP functions,
+        * therefore it must be called in a PJSIP thread.
+        */
+       ao2_ref(wstransport, -1);
 
-       pjsip_transport_shutdown(transport);
        return 0;
 }
 
@@ -112,58 +144,125 @@ struct transport_create_data {
 static int transport_create(void *data)
 {
        struct transport_create_data *create_data = data;
-       struct ws_transport *newtransport;
+       struct ws_transport *newtransport = NULL;
+       pjsip_tp_state_callback state_cb;
 
        pjsip_endpoint *endpt = ast_sip_get_pjsip_endpoint();
        struct pjsip_tpmgr *tpmgr = pjsip_endpt_get_tpmgr(endpt);
 
+       char *ws_addr_str;
        pj_pool_t *pool;
-
        pj_str_t buf;
+       pj_status_t status;
 
-       if (!(pool = pjsip_endpt_create_pool(endpt, "ws", 512, 512))) {
-               ast_log(LOG_ERROR, "Failed to allocate WebSocket endpoint pool.\n");
-               return -1;
+       newtransport = ao2_t_alloc_options(sizeof(*newtransport), transport_dtor,
+                       AO2_ALLOC_OPT_LOCK_NOLOCK, "pjsip websocket transport");
+       if (!newtransport) {
+               ast_log(LOG_ERROR, "Failed to allocate WebSocket transport.\n");
+               goto on_error;
        }
 
-       if (!(newtransport = PJ_POOL_ZALLOC_T(pool, struct ws_transport))) {
-               ast_log(LOG_ERROR, "Failed to allocate WebSocket transport.\n");
-               pjsip_endpt_release_pool(endpt, pool);
-               return -1;
+       /* Give websocket transport a unique name for its lifetime */
+       snprintf(newtransport->transport.obj_name, PJ_MAX_OBJ_NAME, "ws%p",
+               &newtransport->transport);
+
+       newtransport->transport.endpt = endpt;
+
+       if (!(pool = pjsip_endpt_create_pool(endpt, "ws", 512, 512))) {
+               ast_log(LOG_ERROR, "Failed to allocate WebSocket endpoint pool.\n");
+               goto on_error;
        }
 
+       newtransport->transport.pool = pool;
        newtransport->ws_session = create_data->ws_session;
 
-       pj_atomic_create(pool, 0, &newtransport->transport.ref_cnt);
-       pj_lock_create_recursive_mutex(pool, pool->obj_name, &newtransport->transport.lock);
+       /* Keep the session until transport dies */
+       ast_websocket_ref(newtransport->ws_session);
 
-       newtransport->transport.pool = pool;
-       pj_sockaddr_parse(pj_AF_UNSPEC(), 0, pj_cstr(&buf, ast_sockaddr_stringify(ast_websocket_remote_address(newtransport->ws_session))), &newtransport->transport.key.rem_addr);
-       newtransport->transport.key.rem_addr.addr.sa_family = pj_AF_INET();
-       newtransport->transport.key.type = ast_websocket_is_secure(newtransport->ws_session) ? transport_type_wss : transport_type_ws;
+       status = pj_atomic_create(pool, 0, &newtransport->transport.ref_cnt);
+       if (status != PJ_SUCCESS) {
+               goto on_error;
+       }
+
+       status = pj_lock_create_recursive_mutex(pool, pool->obj_name, &newtransport->transport.lock);
+       if (status != PJ_SUCCESS) {
+               goto on_error;
+       }
+
+       /*
+        * The type_name here is mostly used by log messages eihter in
+        * pjproject or Asterisk.  Other places are reconstituting subscriptions
+        * after a restart (which could never work for a websocket connection anyway),
+        * received MESSAGE requests to set PJSIP_TRANSPORT, and most importantly
+        * by pjproject when generating the Via header.
+        */
+       newtransport->transport.type_name = ast_websocket_is_secure(newtransport->ws_session)
+               ? "WSS" : "WS";
+
+       ws_addr_str = ast_sockaddr_stringify(ast_websocket_remote_address(newtransport->ws_session));
+       ast_debug(4, "Creating websocket transport for %s:%s\n",
+               newtransport->transport.type_name, ws_addr_str);
+
+       pj_sockaddr_parse(pj_AF_UNSPEC(), 0, pj_cstr(&buf, ws_addr_str), &newtransport->transport.key.rem_addr);
+       if (newtransport->transport.key.rem_addr.addr.sa_family == pj_AF_INET6()) {
+               newtransport->transport.key.type = transport_type_wss_ipv6;
+               newtransport->transport.local_name.host.ptr = (char *)pj_pool_alloc(pool, PJ_INET6_ADDRSTRLEN);
+               pj_sockaddr_print(&newtransport->transport.key.rem_addr, newtransport->transport.local_name.host.ptr, PJ_INET6_ADDRSTRLEN, 0);
+       } else {
+               newtransport->transport.key.type = transport_type_wss;
+               newtransport->transport.local_name.host.ptr = (char *)pj_pool_alloc(pool, PJ_INET_ADDRSTRLEN);
+               pj_sockaddr_print(&newtransport->transport.key.rem_addr, newtransport->transport.local_name.host.ptr, PJ_INET_ADDRSTRLEN, 0);
+       }
 
        newtransport->transport.addr_len = pj_sockaddr_get_len(&newtransport->transport.key.rem_addr);
 
        pj_sockaddr_cp(&newtransport->transport.local_addr, &newtransport->transport.key.rem_addr);
 
-       newtransport->transport.local_name.host.ptr = (char *)pj_pool_alloc(pool, newtransport->transport.addr_len+4);
-       pj_sockaddr_print(&newtransport->transport.key.rem_addr, newtransport->transport.local_name.host.ptr, newtransport->transport.addr_len+4, 0);
        newtransport->transport.local_name.host.slen = pj_ansi_strlen(newtransport->transport.local_name.host.ptr);
        newtransport->transport.local_name.port = pj_sockaddr_get_port(&newtransport->transport.key.rem_addr);
 
-       newtransport->transport.type_name = (char *)pjsip_transport_get_type_name(newtransport->transport.key.type);
        newtransport->transport.flag = pjsip_transport_get_flag_from_type((pjsip_transport_type_e)newtransport->transport.key.type);
        newtransport->transport.info = (char *)pj_pool_alloc(newtransport->transport.pool, 64);
 
-       newtransport->transport.endpt = endpt;
+       newtransport->transport.dir = PJSIP_TP_DIR_INCOMING;
        newtransport->transport.tpmgr = tpmgr;
        newtransport->transport.send_msg = &ws_send_msg;
        newtransport->transport.destroy = &ws_destroy;
 
-       pjsip_transport_register(newtransport->transport.tpmgr, (pjsip_transport *)newtransport);
+       status = pjsip_transport_register(newtransport->transport.tpmgr,
+                       (pjsip_transport *)newtransport);
+       if (status != PJ_SUCCESS) {
+               goto on_error;
+       }
+
+       /* Add a reference for pjsip transport manager */
+       ao2_ref(newtransport, +1);
+
+       newtransport->rdata.tp_info.transport = &newtransport->transport;
+       newtransport->rdata.tp_info.pool = pjsip_endpt_create_pool(endpt, "rtd%p",
+               PJSIP_POOL_RDATA_LEN, PJSIP_POOL_RDATA_INC);
+       if (!newtransport->rdata.tp_info.pool) {
+               ast_log(LOG_ERROR, "Failed to allocate WebSocket rdata.\n");
+               pjsip_transport_destroy((pjsip_transport *)newtransport);
+               goto on_error;
+       }
 
        create_data->transport = newtransport;
+
+       /* Notify application of transport state */
+       state_cb = pjsip_tpmgr_get_state_cb(newtransport->transport.tpmgr);
+       if (state_cb) {
+               pjsip_transport_state_info state_info;
+
+               memset(&state_info, 0, sizeof(state_info));
+               state_cb(&newtransport->transport, PJSIP_TP_STATE_CONNECTED, &state_info);
+       }
+
        return 0;
+
+on_error:
+       ao2_cleanup(newtransport);
+       return -1;
 }
 
 struct transport_read_data {
@@ -184,19 +283,16 @@ static int transport_read(void *data)
        pjsip_rx_data *rdata = &newtransport->rdata;
        int recvd;
        pj_str_t buf;
-
-       rdata->tp_info.pool = newtransport->transport.pool;
-       rdata->tp_info.transport = &newtransport->transport;
+       int pjsip_pkt_len;
 
        pj_gettimeofday(&rdata->pkt_info.timestamp);
 
-       pj_memcpy(rdata->pkt_info.packet, read_data->payload, sizeof(rdata->pkt_info.packet));
-       rdata->pkt_info.len = read_data->payload_len;
+       pjsip_pkt_len = PJSIP_MAX_PKT_LEN < read_data->payload_len ? PJSIP_MAX_PKT_LEN : read_data->payload_len;
+       pj_memcpy(rdata->pkt_info.packet, read_data->payload, pjsip_pkt_len);
+       rdata->pkt_info.len = pjsip_pkt_len;
        rdata->pkt_info.zero = 0;
 
        pj_sockaddr_parse(pj_AF_UNSPEC(), 0, pj_cstr(&buf, ast_sockaddr_stringify(ast_websocket_remote_address(session))), &rdata->pkt_info.src_addr);
-       rdata->pkt_info.src_addr.addr.sa_family = pj_AF_INET();
-
        rdata->pkt_info.src_addr_len = sizeof(rdata->pkt_info.src_addr);
 
        pj_ansi_strcpy(rdata->pkt_info.src_name, ast_sockaddr_stringify_host(ast_websocket_remote_address(session)));
@@ -204,22 +300,30 @@ static int transport_read(void *data)
 
        recvd = pjsip_tpmgr_receive_packet(rdata->tp_info.transport->tpmgr, rdata);
 
+       pj_pool_reset(rdata->tp_info.pool);
+
        return (read_data->payload_len == recvd) ? 0 : -1;
 }
 
 static int get_write_timeout(void)
 {
        int write_timeout = -1;
-       struct ao2_container *transports;
+       struct ao2_container *transport_states;
+
+       transport_states = ast_sip_get_transport_states();
 
-       transports = ast_sorcery_retrieve_by_fields(ast_sip_get_sorcery(), "transport", AST_RETRIEVE_FLAG_ALL, NULL);
+       if (transport_states) {
+               struct ao2_iterator it_transport_states = ao2_iterator_init(transport_states, 0);
+               struct ast_sip_transport_state *transport_state;
 
-       if (transports) {
-               struct ao2_iterator it_transports = ao2_iterator_init(transports, 0);
-               struct ast_sip_transport *transport;
+               for (; (transport_state = ao2_iterator_next(&it_transport_states)); ao2_cleanup(transport_state)) {
+                       struct ast_sip_transport *transport;
 
-               for (; (transport = ao2_iterator_next(&it_transports)); ao2_cleanup(transport)) {
-                       if (transport->type != AST_TRANSPORT_WS && transport->type != AST_TRANSPORT_WSS) {
+                       if (transport_state->type != AST_TRANSPORT_WS && transport_state->type != AST_TRANSPORT_WSS) {
+                               continue;
+                       }
+                       transport = ast_sorcery_retrieve_by_id(ast_sip_get_sorcery(), "transport", transport_state->id);
+                       if (!transport) {
                                continue;
                        }
                        ast_debug(5, "Found %s transport with write timeout: %d\n",
@@ -227,7 +331,8 @@ static int get_write_timeout(void)
                                transport->write_timeout);
                        write_timeout = MAX(write_timeout, transport->write_timeout);
                }
-               ao2_cleanup(transports);
+               ao2_iterator_destroy(&it_transport_states);
+               ao2_cleanup(transport_states);
        }
 
        if (write_timeout < 0) {
@@ -238,14 +343,22 @@ static int get_write_timeout(void)
        return write_timeout;
 }
 
-/*!
- \brief WebSocket connection handler.
- */
+static struct ast_taskprocessor *create_websocket_serializer(void)
+{
+       char tps_name[AST_TASKPROCESSOR_MAX_NAME + 1];
+
+       /* Create name with seq number appended. */
+       ast_taskprocessor_build_name(tps_name, sizeof(tps_name), "pjsip/websocket");
+
+       return ast_sip_create_serializer(tps_name);
+}
+
+/*! \brief WebSocket connection handler. */
 static void websocket_cb(struct ast_websocket *session, struct ast_variable *parameters, struct ast_variable *headers)
 {
-       struct ast_taskprocessor *serializer = NULL;
+       struct ast_taskprocessor *serializer;
        struct transport_create_data create_data;
-       struct ws_transport *transport = NULL;
+       struct ws_transport *transport;
        struct transport_read_data read_data;
 
        if (ast_websocket_set_nonblock(session)) {
@@ -258,7 +371,8 @@ static void websocket_cb(struct ast_websocket *session, struct ast_variable *par
                return;
        }
 
-       if (!(serializer = ast_sip_create_serializer())) {
+       serializer = create_websocket_serializer();
+       if (!serializer) {
                ast_websocket_unref(session);
                return;
        }
@@ -267,6 +381,7 @@ static void websocket_cb(struct ast_websocket *session, struct ast_variable *par
 
        if (ast_sip_push_task_synchronous(serializer, transport_create, &create_data)) {
                ast_log(LOG_ERROR, "Could not create WebSocket transport.\n");
+               ast_taskprocessor_unreference(serializer);
                ast_websocket_unref(session);
                return;
        }
@@ -301,22 +416,31 @@ static void websocket_cb(struct ast_websocket *session, struct ast_variable *par
 static pj_bool_t websocket_on_rx_msg(pjsip_rx_data *rdata)
 {
        static const pj_str_t STR_WS = { "ws", 2 };
-       static const pj_str_t STR_WSS = { "wss", 3 };
        pjsip_contact_hdr *contact;
 
        long type = rdata->tp_info.transport->key.type;
 
-       if (type != (long)transport_type_ws && type != (long)transport_type_wss) {
+       if (type != (long) transport_type_wss && type != (long) transport_type_wss_ipv6) {
                return PJ_FALSE;
        }
 
-       if ((contact = pjsip_msg_find_hdr(rdata->msg_info.msg, PJSIP_H_CONTACT, NULL)) &&
-               (PJSIP_URI_SCHEME_IS_SIP(contact->uri) || PJSIP_URI_SCHEME_IS_SIPS(contact->uri))) {
+       contact = pjsip_msg_find_hdr(rdata->msg_info.msg, PJSIP_H_CONTACT, NULL);
+       if (contact
+               && !contact->star
+               && (PJSIP_URI_SCHEME_IS_SIP(contact->uri) || PJSIP_URI_SCHEME_IS_SIPS(contact->uri))) {
                pjsip_sip_uri *uri = pjsip_uri_get_uri(contact->uri);
+               const pj_str_t *txp_str = &STR_WS;
+
+               ast_debug(4, "%s re-writing Contact URI from %.*s:%d%s%.*s to %s:%d;transport=%s\n",
+                       pjsip_rx_data_get_info(rdata),
+                       (int)pj_strlen(&uri->host), pj_strbuf(&uri->host), uri->port,
+                       pj_strlen(&uri->transport_param) ? ";transport=" : "",
+                       (int)pj_strlen(&uri->transport_param), pj_strbuf(&uri->transport_param),
+                       rdata->pkt_info.src_name ?: "", rdata->pkt_info.src_port, pj_strbuf(txp_str));
 
                pj_cstr(&uri->host, rdata->pkt_info.src_name);
                uri->port = rdata->pkt_info.src_port;
-               pj_strdup(rdata->tp_info.pool, &uri->transport_param, (type == (long)transport_type_ws) ? &STR_WS : &STR_WSS);
+               pj_strdup(rdata->tp_info.pool, &uri->transport_param, txp_str);
        }
 
        rdata->msg_info.via->rport_param = 0;
@@ -329,18 +453,52 @@ static pjsip_module websocket_module = {
        .id = -1,
        .priority = PJSIP_MOD_PRIORITY_TRANSPORT_LAYER,
        .on_rx_request = websocket_on_rx_msg,
+       .on_rx_response = websocket_on_rx_msg,
+};
+
+/*! \brief Function called when an INVITE goes out */
+static void websocket_outgoing_invite_request(struct ast_sip_session *session, struct pjsip_tx_data *tdata)
+{
+       if (session->inv_session->state == PJSIP_INV_STATE_NULL) {
+               pjsip_dlg_add_usage(session->inv_session->dlg, &websocket_module, NULL);
+       }
+}
+
+/*! \brief Supplement for adding Websocket functionality to dialog */
+static struct ast_sip_session_supplement websocket_supplement = {
+       .method = "INVITE",
+       .priority = AST_SIP_SUPPLEMENT_PRIORITY_FIRST + 1,
+       .outgoing_request = websocket_outgoing_invite_request,
 };
 
 static int load_module(void)
 {
-       pjsip_transport_register_type(PJSIP_TRANSPORT_RELIABLE, "WS", 5060, &transport_type_ws);
-       pjsip_transport_register_type(PJSIP_TRANSPORT_RELIABLE, "WSS", 5060, &transport_type_wss);
+       CHECK_PJSIP_MODULE_LOADED();
+
+       /*
+        * We only need one transport type name (ws) defined.  Firefox
+        * and Chrome do not support anything other than secure websockets
+        * anymore.
+        *
+        * Also we really cannot have two transports with the same name
+        * and address family because it would be ambiguous.  Outgoing
+        * requests may try to find the transport by name and pjproject
+        * only finds the first one registered.
+        */
+       pjsip_transport_register_type(PJSIP_TRANSPORT_RELIABLE | PJSIP_TRANSPORT_SECURE, "ws", 5060, &transport_type_wss);
+       pjsip_transport_register_type(PJSIP_TRANSPORT_RELIABLE | PJSIP_TRANSPORT_SECURE | PJSIP_TRANSPORT_IPV6, "ws", 5060, &transport_type_wss_ipv6);
 
        if (ast_sip_register_service(&websocket_module) != PJ_SUCCESS) {
                return AST_MODULE_LOAD_DECLINE;
        }
 
+       if (ast_sip_session_register_supplement(&websocket_supplement)) {
+               ast_sip_unregister_service(&websocket_module);
+               return AST_MODULE_LOAD_DECLINE;
+       }
+
        if (ast_websocket_add_protocol("sip", websocket_cb)) {
+               ast_sip_session_unregister_supplement(&websocket_supplement);
                ast_sip_unregister_service(&websocket_module);
                return AST_MODULE_LOAD_DECLINE;
        }
@@ -351,14 +509,15 @@ static int load_module(void)
 static int unload_module(void)
 {
        ast_sip_unregister_service(&websocket_module);
+       ast_sip_session_unregister_supplement(&websocket_supplement);
        ast_websocket_remove_protocol("sip", websocket_cb);
 
        return 0;
 }
 
 AST_MODULE_INFO(ASTERISK_GPL_KEY, AST_MODFLAG_LOAD_ORDER, "PJSIP WebSocket Transport Support",
-               .support_level = AST_MODULE_SUPPORT_CORE,
-               .load = load_module,
-               .unload = unload_module,
-               .load_pri = AST_MODPRI_APP_DEPEND,
-          );
+       .support_level = AST_MODULE_SUPPORT_CORE,
+       .load = load_module,
+       .unload = unload_module,
+       .load_pri = AST_MODPRI_APP_DEPEND,
+);