Fix segfault for certain invalid WebSocket input.
[asterisk/asterisk.git] / res / res_http_websocket.c
index cfc613c..077c212 100644 (file)
@@ -77,9 +77,6 @@ struct websocket_protocol {
        ast_websocket_callback callback; /*!< Callback called when a new session is established */
 };
 
-/*! \brief Container for registered protocols */
-static struct ao2_container *protocols;
-
 /*! \brief Hashing function for protocols */
 static int protocol_hash_fn(const void *obj, const int flags)
 {
@@ -105,6 +102,36 @@ static void protocol_destroy_fn(void *obj)
        ast_free(protocol->name);
 }
 
+/*! \brief Structure for a WebSocket server */
+struct ast_websocket_server {
+       struct ao2_container *protocols; /*!< Container for registered protocols */
+};
+
+static void websocket_server_dtor(void *obj)
+{
+       struct ast_websocket_server *server = obj;
+       ao2_cleanup(server->protocols);
+       server->protocols = NULL;
+}
+
+struct ast_websocket_server *ast_websocket_server_create(void)
+{
+       RAII_VAR(struct ast_websocket_server *, server, NULL, ao2_cleanup);
+
+       server = ao2_alloc(sizeof(*server), websocket_server_dtor);
+       if (!server) {
+               return NULL;
+       }
+
+       server->protocols = ao2_container_alloc(MAX_PROTOCOL_BUCKETS, protocol_hash_fn, protocol_cmp_fn);
+       if (!server->protocols) {
+               return NULL;
+       }
+
+       ao2_ref(server, +1);
+       return server;
+}
+
 /*! \brief Destructor function for sessions */
 static void session_destroy_fn(void *obj)
 {
@@ -118,34 +145,38 @@ static void session_destroy_fn(void *obj)
        ast_free(session->payload);
 }
 
-int AST_OPTIONAL_API_NAME(ast_websocket_add_protocol)(const char *name, ast_websocket_callback callback)
+int AST_OPTIONAL_API_NAME(ast_websocket_server_add_protocol)(struct ast_websocket_server *server, const char *name, ast_websocket_callback callback)
 {
        struct websocket_protocol *protocol;
 
-       ao2_lock(protocols);
+       if (!server->protocols) {
+               return -1;
+       }
+
+       ao2_lock(server->protocols);
 
        /* Ensure a second protocol handler is not registered for the same protocol */
-       if ((protocol = ao2_find(protocols, name, OBJ_KEY | OBJ_NOLOCK))) {
+       if ((protocol = ao2_find(server->protocols, name, OBJ_KEY | OBJ_NOLOCK))) {
                ao2_ref(protocol, -1);
-               ao2_unlock(protocols);
+               ao2_unlock(server->protocols);
                return -1;
        }
 
        if (!(protocol = ao2_alloc(sizeof(*protocol), protocol_destroy_fn))) {
-               ao2_unlock(protocols);
+               ao2_unlock(server->protocols);
                return -1;
        }
 
        if (!(protocol->name = ast_strdup(name))) {
                ao2_ref(protocol, -1);
-               ao2_unlock(protocols);
+               ao2_unlock(server->protocols);
                return -1;
        }
 
        protocol->callback = callback;
 
-       ao2_link_flags(protocols, protocol, OBJ_NOLOCK);
-       ao2_unlock(protocols);
+       ao2_link_flags(server->protocols, protocol, OBJ_NOLOCK);
+       ao2_unlock(server->protocols);
        ao2_ref(protocol, -1);
 
        ast_verb(2, "WebSocket registered sub-protocol '%s'\n", name);
@@ -153,11 +184,11 @@ int AST_OPTIONAL_API_NAME(ast_websocket_add_protocol)(const char *name, ast_webs
        return 0;
 }
 
-int AST_OPTIONAL_API_NAME(ast_websocket_remove_protocol)(const char *name, ast_websocket_callback callback)
+int AST_OPTIONAL_API_NAME(ast_websocket_server_remove_protocol)(struct ast_websocket_server *server, const char *name, ast_websocket_callback callback)
 {
        struct websocket_protocol *protocol;
 
-       if (!(protocol = ao2_find(protocols, name, OBJ_KEY))) {
+       if (!(protocol = ao2_find(server->protocols, name, OBJ_KEY))) {
                return -1;
        }
 
@@ -166,7 +197,7 @@ int AST_OPTIONAL_API_NAME(ast_websocket_remove_protocol)(const char *name, ast_w
                return -1;
        }
 
-       ao2_unlink(protocols, protocol);
+       ao2_unlink(server->protocols, protocol);
        ao2_ref(protocol, -1);
 
        ast_verb(2, "WebSocket unregistered sub-protocol '%s'\n", name);
@@ -466,14 +497,14 @@ int AST_OPTIONAL_API_NAME(ast_websocket_read)(struct ast_websocket *session, cha
        return 0;
 }
 
-/*! \brief Callback that is executed everytime an HTTP request is received by this module */
-static int websocket_callback(struct ast_tcptls_session_instance *ser, const struct ast_http_uri *urih, const char *uri, enum ast_http_method method, struct ast_variable *get_vars, struct ast_variable *headers)
+int ast_websocket_uri_cb(struct ast_tcptls_session_instance *ser, const struct ast_http_uri *urih, const char *uri, enum ast_http_method method, struct ast_variable *get_vars, struct ast_variable *headers)
 {
        struct ast_variable *v;
        char *upgrade = NULL, *key = NULL, *key1 = NULL, *key2 = NULL, *protos = NULL, *requested_protocols = NULL, *protocol = NULL;
        int version = 0, flags = 1;
        struct websocket_protocol *protocol_handler = NULL;
        struct ast_websocket *session;
+       struct ast_websocket_server *server;
 
        /* Upgrade requests are only permitted on GET methods */
        if (method != AST_HTTP_GET) {
@@ -481,6 +512,8 @@ static int websocket_callback(struct ast_tcptls_session_instance *ser, const str
                return -1;
        }
 
+       server = urih->data;
+
        /* Get the minimum headers required to satisfy our needs */
        for (v = headers; v; v = v->next) {
                if (!strcasecmp(v->name, "Upgrade")) {
@@ -503,12 +536,12 @@ static int websocket_callback(struct ast_tcptls_session_instance *ser, const str
 
        /* If this is not a websocket upgrade abort */
        if (!upgrade || strcasecmp(upgrade, "websocket")) {
-               ast_log(LOG_WARNING, "WebSocket connection from '%s' could not be accepted - did not request WebSocket",
+               ast_log(LOG_WARNING, "WebSocket connection from '%s' could not be accepted - did not request WebSocket\n",
                        ast_sockaddr_stringify(&ser->remote_address));
                ast_http_error(ser, 426, "Upgrade Required", NULL);
                return -1;
        } else if (ast_strlen_zero(requested_protocols)) {
-               ast_log(LOG_WARNING, "WebSocket connection from '%s' could not be accepted - no protocols requested",
+               ast_log(LOG_WARNING, "WebSocket connection from '%s' could not be accepted - no protocols requested\n",
                        ast_sockaddr_stringify(&ser->remote_address));
                fputs("HTTP/1.1 400 Bad Request\r\n"
                      "Sec-WebSocket-Version: 7, 8, 13\r\n\r\n", ser->f);
@@ -516,7 +549,7 @@ static int websocket_callback(struct ast_tcptls_session_instance *ser, const str
        } else if (key1 && key2) {
                /* Specification defined in http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76 and
                 * http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-00 -- not currently supported*/
-               ast_log(LOG_WARNING, "WebSocket connection from '%s' could not be accepted - unsupported version '00/76' chosen",
+               ast_log(LOG_WARNING, "WebSocket connection from '%s' could not be accepted - unsupported version '00/76' chosen\n",
                        ast_sockaddr_stringify(&ser->remote_address));
                fputs("HTTP/1.1 400 Bad Request\r\n"
                      "Sec-WebSocket-Version: 7, 8, 13\r\n\r\n", ser->f);
@@ -525,7 +558,7 @@ static int websocket_callback(struct ast_tcptls_session_instance *ser, const str
 
        /* Iterate through the requested protocols trying to find one that we have a handler for */
        while ((protocol = strsep(&requested_protocols, ","))) {
-               if ((protocol_handler = ao2_find(protocols, ast_strip(protocol), OBJ_KEY))) {
+               if ((protocol_handler = ao2_find(server->protocols, ast_strip(protocol), OBJ_KEY))) {
                        break;
                }
        }
@@ -544,11 +577,20 @@ static int websocket_callback(struct ast_tcptls_session_instance *ser, const str
                /* Version 7 defined in specification http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-07 */
                /* Version 8 defined in specification http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-10 */
                /* Version 13 defined in specification http://tools.ietf.org/html/rfc6455 */
-               char combined[strlen(key) + strlen(WEBSOCKET_GUID) + 1], base64[64];
+               char *combined, base64[64];
+               unsigned combined_length;
                uint8_t sha[20];
 
+               combined_length = (key ? strlen(key) : 0) + strlen(WEBSOCKET_GUID) + 1;
+               if (!key || combined_length > 8192) { /* no stack overflows please */
+                       fputs("HTTP/1.1 400 Bad Request\r\n"
+                             "Sec-WebSocket-Version: 7, 8, 13\r\n\r\n", ser->f);
+                       ao2_ref(protocol_handler, -1);
+                       return 0;
+               }
+
                if (!(session = ao2_alloc(sizeof(*session), session_destroy_fn))) {
-                       ast_log(LOG_WARNING, "WebSocket connection from '%s' could not be accepted",
+                       ast_log(LOG_WARNING, "WebSocket connection from '%s' could not be accepted\n",
                                ast_sockaddr_stringify(&ser->remote_address));
                        fputs("HTTP/1.1 400 Bad Request\r\n"
                              "Sec-WebSocket-Version: 7, 8, 13\r\n\r\n", ser->f);
@@ -556,7 +598,8 @@ static int websocket_callback(struct ast_tcptls_session_instance *ser, const str
                        return 0;
                }
 
-               snprintf(combined, sizeof(combined), "%s%s", key, WEBSOCKET_GUID);
+               combined = ast_alloca(combined_length);
+               snprintf(combined, combined_length, "%s%s", key, WEBSOCKET_GUID);
                ast_sha1_hash_uint(sha, combined);
                ast_base64encode(base64, (const unsigned char*)sha, 20, sizeof(base64));
 
@@ -571,7 +614,7 @@ static int websocket_callback(struct ast_tcptls_session_instance *ser, const str
        } else {
 
                /* Specification defined in http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-75 or completely unknown */
-               ast_log(LOG_WARNING, "WebSocket connection from '%s' could not be accepted - unsupported version '%d' chosen",
+               ast_log(LOG_WARNING, "WebSocket connection from '%s' could not be accepted - unsupported version '%d' chosen\n",
                        ast_sockaddr_stringify(&ser->remote_address), version ? version : 75);
                fputs("HTTP/1.1 400 Bad Request\r\n"
                      "Sec-WebSocket-Version: 7, 8, 13\r\n\r\n", ser->f);
@@ -581,7 +624,7 @@ static int websocket_callback(struct ast_tcptls_session_instance *ser, const str
 
        /* Enable keepalive on all sessions so the underlying user does not have to */
        if (setsockopt(ser->fd, SOL_SOCKET, SO_KEEPALIVE, &flags, sizeof(flags))) {
-               ast_log(LOG_WARNING, "WebSocket connection from '%s' could not be accepted - failed to enable keepalive",
+               ast_log(LOG_WARNING, "WebSocket connection from '%s' could not be accepted - failed to enable keepalive\n",
                        ast_sockaddr_stringify(&ser->remote_address));
                fputs("HTTP/1.1 400 Bad Request\r\n"
                      "Sec-WebSocket-Version: 7, 8, 13\r\n\r\n", ser->f);
@@ -611,7 +654,7 @@ static int websocket_callback(struct ast_tcptls_session_instance *ser, const str
 }
 
 static struct ast_http_uri websocketuri = {
-       .callback = websocket_callback,
+       .callback = ast_websocket_uri_cb,
        .description = "Asterisk HTTP WebSocket",
        .uri = "ws",
        .has_subtree = 0,
@@ -656,9 +699,30 @@ end:
        ast_websocket_unref(session);
 }
 
+int AST_OPTIONAL_API_NAME(ast_websocket_add_protocol)(const char *name, ast_websocket_callback callback)
+{
+       struct ast_websocket_server *ws_server = websocketuri.data;
+       if (!ws_server) {
+               return -1;
+       }
+       return ast_websocket_server_add_protocol(ws_server, name, callback);
+}
+
+int AST_OPTIONAL_API_NAME(ast_websocket_remove_protocol)(const char *name, ast_websocket_callback callback)
+{
+       struct ast_websocket_server *ws_server = websocketuri.data;
+       if (!ws_server) {
+               return -1;
+       }
+       return ast_websocket_server_remove_protocol(ws_server, name, callback);
+}
+
 static int load_module(void)
 {
-       protocols = ao2_container_alloc(MAX_PROTOCOL_BUCKETS, protocol_hash_fn, protocol_cmp_fn);
+       websocketuri.data = ast_websocket_server_create();
+       if (!websocketuri.data) {
+               return AST_MODULE_LOAD_FAILURE;
+       }
        ast_http_uri_link(&websocketuri);
        ast_websocket_add_protocol("echo", websocket_echo_callback);
 
@@ -669,7 +733,8 @@ static int unload_module(void)
 {
        ast_websocket_remove_protocol("echo", websocket_echo_callback);
        ast_http_uri_unlink(&websocketuri);
-       ao2_ref(protocols, -1);
+       ao2_ref(websocketuri.data, -1);
+       websocketuri.data = NULL;
 
        return 0;
 }