res_pjsip_transport_websocket: Fix crash when the Contact header is not a URI.
[asterisk/asterisk.git] / res / res_pjsip_transport_websocket.c
index f36ede3..7de65dd 100644 (file)
@@ -90,18 +90,17 @@ static pj_status_t ws_destroy(pjsip_transport *transport)
 
        pjsip_endpt_release_pool(wstransport->transport.endpt, wstransport->transport.pool);
 
+       if (wstransport->rdata.tp_info.pool) {
+               pjsip_endpt_release_pool(wstransport->transport.endpt, wstransport->rdata.tp_info.pool);
+       }
+
        return PJ_SUCCESS;
 }
 
 static int transport_shutdown(void *data)
 {
-       RAII_VAR(struct ast_sip_contact_transport *, ct, NULL, ao2_cleanup);
        pjsip_transport *transport = data;
 
-       if ((ct = ast_sip_location_retrieve_contact_transport_by_transport(transport))) {
-               ast_sip_location_delete_contact_transport(ct);
-       }
-
        pjsip_transport_shutdown(transport);
        return 0;
 }
@@ -167,6 +166,15 @@ static int transport_create(void *data)
 
        pjsip_transport_register(newtransport->transport.tpmgr, (pjsip_transport *)newtransport);
 
+       newtransport->rdata.tp_info.transport = &newtransport->transport;
+       newtransport->rdata.tp_info.pool = pjsip_endpt_create_pool(endpt, "rtd%p",
+               PJSIP_POOL_RDATA_LEN, PJSIP_POOL_RDATA_INC);
+       if (!newtransport->rdata.tp_info.pool) {
+               ast_log(LOG_ERROR, "Failed to allocate WebSocket rdata.\n");
+               pjsip_endpt_release_pool(endpt, pool);
+               return -1;
+       }
+
        create_data->transport = newtransport;
        return 0;
 }
@@ -190,9 +198,6 @@ static int transport_read(void *data)
        int recvd;
        pj_str_t buf;
 
-       rdata->tp_info.pool = newtransport->transport.pool;
-       rdata->tp_info.transport = &newtransport->transport;
-
        pj_gettimeofday(&rdata->pkt_info.timestamp);
 
        pj_memcpy(rdata->pkt_info.packet, read_data->payload, sizeof(rdata->pkt_info.packet));
@@ -209,9 +214,42 @@ static int transport_read(void *data)
 
        recvd = pjsip_tpmgr_receive_packet(rdata->tp_info.transport->tpmgr, rdata);
 
+       pj_pool_reset(rdata->tp_info.pool);
+
        return (read_data->payload_len == recvd) ? 0 : -1;
 }
 
+static int get_write_timeout(void)
+{
+       int write_timeout = -1;
+       struct ao2_container *transports;
+
+       transports = ast_sorcery_retrieve_by_fields(ast_sip_get_sorcery(), "transport", AST_RETRIEVE_FLAG_ALL, NULL);
+
+       if (transports) {
+               struct ao2_iterator it_transports = ao2_iterator_init(transports, 0);
+               struct ast_sip_transport *transport;
+
+               for (; (transport = ao2_iterator_next(&it_transports)); ao2_cleanup(transport)) {
+                       if (transport->type != AST_TRANSPORT_WS && transport->type != AST_TRANSPORT_WSS) {
+                               continue;
+                       }
+                       ast_debug(5, "Found %s transport with write timeout: %d\n",
+                               transport->type == AST_TRANSPORT_WS ? "WS" : "WSS",
+                               transport->write_timeout);
+                       write_timeout = MAX(write_timeout, transport->write_timeout);
+               }
+               ao2_cleanup(transports);
+       }
+
+       if (write_timeout < 0) {
+               write_timeout = AST_DEFAULT_WEBSOCKET_WRITE_TIMEOUT;
+       }
+
+       ast_debug(1, "Write timeout for WS/WSS transports: %d\n", write_timeout);
+       return write_timeout;
+}
+
 /*!
  \brief WebSocket connection handler.
  */
@@ -220,12 +258,18 @@ static void websocket_cb(struct ast_websocket *session, struct ast_variable *par
        struct ast_taskprocessor *serializer = NULL;
        struct transport_create_data create_data;
        struct ws_transport *transport = NULL;
+       struct transport_read_data read_data;
 
        if (ast_websocket_set_nonblock(session)) {
                ast_websocket_unref(session);
                return;
        }
 
+       if (ast_websocket_set_timeout(session, get_write_timeout())) {
+               ast_websocket_unref(session);
+               return;
+       }
+
        if (!(serializer = ast_sip_create_serializer())) {
                ast_websocket_unref(session);
                return;
@@ -240,9 +284,9 @@ static void websocket_cb(struct ast_websocket *session, struct ast_variable *par
        }
 
        transport = create_data.transport;
+       read_data.transport = transport;
 
        while (ast_wait_for_input(ast_websocket_fd(session), -1) > 0) {
-               struct transport_read_data read_data;
                enum ast_websocket_opcode opcode;
                int fragmented;
 
@@ -251,9 +295,7 @@ static void websocket_cb(struct ast_websocket *session, struct ast_variable *par
                }
 
                if (opcode == AST_WEBSOCKET_OPCODE_TEXT || opcode == AST_WEBSOCKET_OPCODE_BINARY) {
-                       read_data.transport = transport;
-
-                       ast_sip_push_task(serializer, transport_read, &read_data);
+                       ast_sip_push_task_synchronous(serializer, transport_read, &read_data);
                } else if (opcode == AST_WEBSOCKET_OPCODE_CLOSE) {
                        break;
                }
@@ -266,72 +308,13 @@ static void websocket_cb(struct ast_websocket *session, struct ast_variable *par
 }
 
 /*!
- * \brief Session supplement handler for avoiding DNS lookup on bogus address.
- */
-static void websocket_outgoing_request(struct ast_sip_session *session, struct pjsip_tx_data *tdata)
-{
-       char contact_uri[PJSIP_MAX_URL_SIZE] = { 0, };
-       RAII_VAR(struct ast_sip_contact_transport *, ct, NULL, ao2_cleanup);
-       pjsip_tpselector selector = { .type = PJSIP_TPSELECTOR_TRANSPORT, };
-
-       const pjsip_sip_uri *request_uri = pjsip_uri_get_uri(tdata->msg->line.req.uri);
-
-       if (pj_stricmp2(&request_uri->transport_param, "WS") && pj_stricmp2(&request_uri->transport_param, "WSS")) {
-               return;
-       }
-
-       pjsip_uri_print(PJSIP_URI_IN_REQ_URI, request_uri, contact_uri, sizeof(contact_uri));
-
-       if (!(ct = ast_sip_location_retrieve_contact_transport_by_uri(contact_uri))) {
-               return;
-       }
-
-       selector.u.transport = ct->transport;
-
-       pjsip_tx_data_set_transport(tdata, &selector);
-
-       tdata->dest_info.addr.count = 1;
-       tdata->dest_info.addr.entry[0].type = ct->transport->key.type;
-       tdata->dest_info.addr.entry[0].addr = ct->transport->key.rem_addr;
-       tdata->dest_info.addr.entry[0].addr_len = ct->transport->addr_len;
-}
-
-static struct ast_sip_session_supplement websocket_supplement = {
-       .outgoing_request = websocket_outgoing_request,
-};
-
-/*!
- * \brief Destructor for ast_sip_contact_transport
- */
-static void contact_transport_destroy(void *obj)
-{
-       struct ast_sip_contact_transport *ct = obj;
-
-       ast_string_field_free_memory(ct);
-}
-
-static void *contact_transport_alloc(void)
-{
-       struct ast_sip_contact_transport *ct = ao2_alloc(sizeof(*ct), contact_transport_destroy);
-
-       if (!ct) {
-               return NULL;
-       }
-
-       if (ast_string_field_init(ct, 256)) {
-               ao2_cleanup(ct);
-               return NULL;
-       }
-
-       return ct;
-}
-
-/*!
  * \brief Store the transport a message came in on, so it can be used for outbound messages to that contact.
  */
 static pj_bool_t websocket_on_rx_msg(pjsip_rx_data *rdata)
 {
-       pjsip_contact_hdr *contact_hdr = NULL;
+       static const pj_str_t STR_WS = { "ws", 2 };
+       static const pj_str_t STR_WSS = { "wss", 3 };
+       pjsip_contact_hdr *contact;
 
        long type = rdata->tp_info.transport->key.type;
 
@@ -339,24 +322,17 @@ static pj_bool_t websocket_on_rx_msg(pjsip_rx_data *rdata)
                return PJ_FALSE;
        }
 
-       if ((contact_hdr = pjsip_msg_find_hdr(rdata->msg_info.msg, PJSIP_H_CONTACT, NULL))) {
-               RAII_VAR(struct ast_sip_contact_transport *, ct, NULL, ao2_cleanup);
-               char contact_uri[PJSIP_MAX_URL_SIZE];
-
-               pjsip_uri_print(PJSIP_URI_IN_CONTACT_HDR, pjsip_uri_get_uri(contact_hdr->uri), contact_uri, sizeof(contact_uri));
-
-               if (!(ct = ast_sip_location_retrieve_contact_transport_by_uri(contact_uri))) {
-                       if (!(ct = contact_transport_alloc())) {
-                               return PJ_FALSE;
-                       }
-
-                       ast_string_field_set(ct, uri, contact_uri);
-                       ct->transport = rdata->tp_info.transport;
+       if ((contact = pjsip_msg_find_hdr(rdata->msg_info.msg, PJSIP_H_CONTACT, NULL)) && !contact->star &&
+               (PJSIP_URI_SCHEME_IS_SIP(contact->uri) || PJSIP_URI_SCHEME_IS_SIPS(contact->uri))) {
+               pjsip_sip_uri *uri = pjsip_uri_get_uri(contact->uri);
 
-                       ast_sip_location_add_contact_transport(ct);
-               }
+               pj_cstr(&uri->host, rdata->pkt_info.src_name);
+               uri->port = rdata->pkt_info.src_port;
+               pj_strdup(rdata->tp_info.pool, &uri->transport_param, (type == (long)transport_type_ws) ? &STR_WS : &STR_WSS);
        }
 
+       rdata->msg_info.via->rport_param = 0;
+
        return PJ_FALSE;
 }
 
@@ -365,6 +341,22 @@ static pjsip_module websocket_module = {
        .id = -1,
        .priority = PJSIP_MOD_PRIORITY_TRANSPORT_LAYER,
        .on_rx_request = websocket_on_rx_msg,
+       .on_rx_response = websocket_on_rx_msg,
+};
+
+/*! \brief Function called when an INVITE goes out */
+static void websocket_outgoing_invite_request(struct ast_sip_session *session, struct pjsip_tx_data *tdata)
+{
+       if (session->inv_session->state == PJSIP_INV_STATE_NULL) {
+               pjsip_dlg_add_usage(session->inv_session->dlg, &websocket_module, NULL);
+       }
+}
+
+/*! \brief Supplement for adding Websocket functionality to dialog */
+static struct ast_sip_session_supplement websocket_supplement = {
+       .method = "INVITE",
+       .priority = AST_SIP_SUPPLEMENT_PRIORITY_FIRST + 1,
+       .outgoing_request = websocket_outgoing_invite_request,
 };
 
 static int load_module(void)
@@ -376,9 +368,13 @@ static int load_module(void)
                return AST_MODULE_LOAD_DECLINE;
        }
 
-       ast_sip_session_register_supplement(&websocket_supplement);
+       if (ast_sip_session_register_supplement(&websocket_supplement)) {
+               ast_sip_unregister_service(&websocket_module);
+               return AST_MODULE_LOAD_DECLINE;
+       }
 
        if (ast_websocket_add_protocol("sip", websocket_cb)) {
+               ast_sip_session_unregister_supplement(&websocket_supplement);
                ast_sip_unregister_service(&websocket_module);
                return AST_MODULE_LOAD_DECLINE;
        }
@@ -396,6 +392,7 @@ static int unload_module(void)
 }
 
 AST_MODULE_INFO(ASTERISK_GPL_KEY, AST_MODFLAG_LOAD_ORDER, "PJSIP WebSocket Transport Support",
+               .support_level = AST_MODULE_SUPPORT_CORE,
                .load = load_module,
                .unload = unload_module,
                .load_pri = AST_MODPRI_APP_DEPEND,