res_pjsip_transport_websocket: Fix crash when the Contact header is not a URI.
[asterisk/asterisk.git] / res / res_pjsip_transport_websocket.c
index bae120a..7de65dd 100644 (file)
@@ -90,6 +90,10 @@ static pj_status_t ws_destroy(pjsip_transport *transport)
 
        pjsip_endpt_release_pool(wstransport->transport.endpt, wstransport->transport.pool);
 
+       if (wstransport->rdata.tp_info.pool) {
+               pjsip_endpt_release_pool(wstransport->transport.endpt, wstransport->rdata.tp_info.pool);
+       }
+
        return PJ_SUCCESS;
 }
 
@@ -162,6 +166,15 @@ static int transport_create(void *data)
 
        pjsip_transport_register(newtransport->transport.tpmgr, (pjsip_transport *)newtransport);
 
+       newtransport->rdata.tp_info.transport = &newtransport->transport;
+       newtransport->rdata.tp_info.pool = pjsip_endpt_create_pool(endpt, "rtd%p",
+               PJSIP_POOL_RDATA_LEN, PJSIP_POOL_RDATA_INC);
+       if (!newtransport->rdata.tp_info.pool) {
+               ast_log(LOG_ERROR, "Failed to allocate WebSocket rdata.\n");
+               pjsip_endpt_release_pool(endpt, pool);
+               return -1;
+       }
+
        create_data->transport = newtransport;
        return 0;
 }
@@ -185,9 +198,6 @@ static int transport_read(void *data)
        int recvd;
        pj_str_t buf;
 
-       rdata->tp_info.pool = newtransport->transport.pool;
-       rdata->tp_info.transport = &newtransport->transport;
-
        pj_gettimeofday(&rdata->pkt_info.timestamp);
 
        pj_memcpy(rdata->pkt_info.packet, read_data->payload, sizeof(rdata->pkt_info.packet));
@@ -204,6 +214,8 @@ static int transport_read(void *data)
 
        recvd = pjsip_tpmgr_receive_packet(rdata->tp_info.transport->tpmgr, rdata);
 
+       pj_pool_reset(rdata->tp_info.pool);
+
        return (read_data->payload_len == recvd) ? 0 : -1;
 }
 
@@ -300,6 +312,8 @@ static void websocket_cb(struct ast_websocket *session, struct ast_variable *par
  */
 static pj_bool_t websocket_on_rx_msg(pjsip_rx_data *rdata)
 {
+       static const pj_str_t STR_WS = { "ws", 2 };
+       static const pj_str_t STR_WSS = { "wss", 3 };
        pjsip_contact_hdr *contact;
 
        long type = rdata->tp_info.transport->key.type;
@@ -308,12 +322,13 @@ static pj_bool_t websocket_on_rx_msg(pjsip_rx_data *rdata)
                return PJ_FALSE;
        }
 
-       if ((contact = pjsip_msg_find_hdr(rdata->msg_info.msg, PJSIP_H_CONTACT, NULL)) &&
+       if ((contact = pjsip_msg_find_hdr(rdata->msg_info.msg, PJSIP_H_CONTACT, NULL)) && !contact->star &&
                (PJSIP_URI_SCHEME_IS_SIP(contact->uri) || PJSIP_URI_SCHEME_IS_SIPS(contact->uri))) {
                pjsip_sip_uri *uri = pjsip_uri_get_uri(contact->uri);
 
                pj_cstr(&uri->host, rdata->pkt_info.src_name);
                uri->port = rdata->pkt_info.src_port;
+               pj_strdup(rdata->tp_info.pool, &uri->transport_param, (type == (long)transport_type_ws) ? &STR_WS : &STR_WSS);
        }
 
        rdata->msg_info.via->rport_param = 0;
@@ -326,6 +341,22 @@ static pjsip_module websocket_module = {
        .id = -1,
        .priority = PJSIP_MOD_PRIORITY_TRANSPORT_LAYER,
        .on_rx_request = websocket_on_rx_msg,
+       .on_rx_response = websocket_on_rx_msg,
+};
+
+/*! \brief Function called when an INVITE goes out */
+static void websocket_outgoing_invite_request(struct ast_sip_session *session, struct pjsip_tx_data *tdata)
+{
+       if (session->inv_session->state == PJSIP_INV_STATE_NULL) {
+               pjsip_dlg_add_usage(session->inv_session->dlg, &websocket_module, NULL);
+       }
+}
+
+/*! \brief Supplement for adding Websocket functionality to dialog */
+static struct ast_sip_session_supplement websocket_supplement = {
+       .method = "INVITE",
+       .priority = AST_SIP_SUPPLEMENT_PRIORITY_FIRST + 1,
+       .outgoing_request = websocket_outgoing_invite_request,
 };
 
 static int load_module(void)
@@ -337,7 +368,13 @@ static int load_module(void)
                return AST_MODULE_LOAD_DECLINE;
        }
 
+       if (ast_sip_session_register_supplement(&websocket_supplement)) {
+               ast_sip_unregister_service(&websocket_module);
+               return AST_MODULE_LOAD_DECLINE;
+       }
+
        if (ast_websocket_add_protocol("sip", websocket_cb)) {
+               ast_sip_session_unregister_supplement(&websocket_supplement);
                ast_sip_unregister_service(&websocket_module);
                return AST_MODULE_LOAD_DECLINE;
        }
@@ -348,12 +385,14 @@ static int load_module(void)
 static int unload_module(void)
 {
        ast_sip_unregister_service(&websocket_module);
+       ast_sip_session_unregister_supplement(&websocket_supplement);
        ast_websocket_remove_protocol("sip", websocket_cb);
 
        return 0;
 }
 
 AST_MODULE_INFO(ASTERISK_GPL_KEY, AST_MODFLAG_LOAD_ORDER, "PJSIP WebSocket Transport Support",
+               .support_level = AST_MODULE_SUPPORT_CORE,
                .load = load_module,
                .unload = unload_module,
                .load_pri = AST_MODPRI_APP_DEPEND,