DNS: Fix some corner cases.
[asterisk/asterisk.git] / main / dns_query_set.c
index 45626d1..40a89e1 100644 (file)
 
 #include "asterisk.h"
 
-ASTERISK_FILE_VERSION(__FILE__, "$Revision$")
+ASTERISK_REGISTER_FILE()
 
 #include "asterisk/vector.h"
 #include "asterisk/astobj2.h"
+#include "asterisk/utils.h"
+#include "asterisk/linkedlists.h"
 #include "asterisk/dns_core.h"
 #include "asterisk/dns_query_set.h"
+#include "asterisk/dns_internal.h"
+#include "asterisk/dns_resolver.h"
 
-/*! \brief A set of DNS queries */
-struct ast_dns_query_set {
-       /*! \brief DNS queries */
-       AST_VECTOR(, struct ast_dns_query *) queries;
-       /*! \brief The total number of completed queries */
-       unsigned int queries_completed;
-       /*! \brief Callback to invoke upon completion */
-       ast_dns_query_set_callback callback;
-       /*! \brief User-specific data */
-       void *user_data;
-};
+/*! \brief The default number of expected queries to be added to the query set */
+#define DNS_QUERY_SET_EXPECTED_QUERY_COUNT 5
+
+/*! \brief Destructor for DNS query set */
+static void dns_query_set_destroy(void *data)
+{
+       struct ast_dns_query_set *query_set = data;
+       int idx;
+
+       for (idx = 0; idx < AST_VECTOR_SIZE(&query_set->queries); ++idx) {
+               struct dns_query_set_query *query = AST_VECTOR_GET_ADDR(&query_set->queries, idx);
+
+               ao2_ref(query->query, -1);
+       }
+       AST_VECTOR_FREE(&query_set->queries);
+
+       ao2_cleanup(query_set->user_data);
+}
 
 struct ast_dns_query_set *ast_dns_query_set_create(void)
 {
-       return NULL;
+       struct ast_dns_query_set *query_set;
+
+       query_set = ao2_alloc_options(sizeof(*query_set), dns_query_set_destroy, AO2_ALLOC_OPT_LOCK_NOLOCK);
+       if (!query_set) {
+               return NULL;
+       }
+
+       if (AST_VECTOR_INIT(&query_set->queries, DNS_QUERY_SET_EXPECTED_QUERY_COUNT)) {
+               ao2_ref(query_set, -1);
+               return NULL;
+       }
+
+       return query_set;
+}
+
+/*! \brief Callback invoked upon completion of a DNS query */
+static void dns_query_set_callback(const struct ast_dns_query *query)
+{
+       struct ast_dns_query_set *query_set = ast_dns_query_get_data(query);
+
+       /* The reference count of the query set is bumped here in case this query holds the last reference */
+       ao2_ref(query_set, +1);
+
+       /* Drop the query set from the query so the query set can be destroyed if this is the last one */
+       ao2_cleanup(((struct ast_dns_query *)query)->user_data);
+       ((struct ast_dns_query *)query)->user_data = NULL;
+
+       if (ast_atomic_fetchadd_int(&query_set->queries_completed, +1) != (AST_VECTOR_SIZE(&query_set->queries) - 1)) {
+               ao2_ref(query_set, -1);
+               return;
+       }
+
+       /* All queries have been completed, invoke final callback */
+       if (query_set->queries_cancelled != AST_VECTOR_SIZE(&query_set->queries)) {
+               query_set->callback(query_set);
+       }
+
+       ao2_cleanup(query_set->user_data);
+       query_set->user_data = NULL;
+
+       ao2_ref(query_set, -1);
 }
 
 int ast_dns_query_set_add(struct ast_dns_query_set *query_set, const char *name, int rr_type, int rr_class)
 {
-       return -1;
+       struct dns_query_set_query query = {
+               .started = 0,
+       };
+
+       ast_assert(!query_set->in_progress);
+       if (query_set->in_progress) {
+               ast_log(LOG_ERROR, "Attempted to add additional query to query set '%p' after resolution has started\n",
+                       query_set);
+               return -1;
+       }
+
+       /*
+        * We are intentionally passing NULL for the user data even
+        * though dns_query_set_callback() is not NULL tolerant.  Doing
+        * this avoids a circular reference chain until the queries are
+        * started.  ast_dns_query_set_resolve_async() will set the
+        * query user_data for us later when we actually kick off the
+        * queries.
+        */
+       query.query = dns_query_alloc(name, rr_type, rr_class, dns_query_set_callback, NULL);
+       if (!query.query) {
+               return -1;
+       }
+
+       if (AST_VECTOR_APPEND(&query_set->queries, query)) {
+               ao2_ref(query.query, -1);
+               return -1;
+       }
+
+       return 0;
 }
 
 size_t ast_dns_query_set_num_queries(const struct ast_dns_query_set *query_set)
 {
-       return 0;
+       return AST_VECTOR_SIZE(&query_set->queries);
 }
 
 struct ast_dns_query *ast_dns_query_set_get(const struct ast_dns_query_set *query_set, unsigned int index)
 {
-       return NULL;
+       /* Only once all queries have been completed can results be retrieved */
+       if (query_set->queries_completed != AST_VECTOR_SIZE(&query_set->queries)) {
+               return NULL;
+       }
+
+       /* If the index exceeds the number of queries... no query for you */
+       if (index >= AST_VECTOR_SIZE(&query_set->queries)) {
+               return NULL;
+       }
+
+       return AST_VECTOR_GET_ADDR(&query_set->queries, index)->query;
 }
 
 void *ast_dns_query_set_get_data(const struct ast_dns_query_set *query_set)
@@ -75,19 +165,122 @@ void *ast_dns_query_set_get_data(const struct ast_dns_query_set *query_set)
 
 void ast_dns_query_set_resolve_async(struct ast_dns_query_set *query_set, ast_dns_query_set_callback callback, void *data)
 {
+       int idx;
+
+       ast_assert(!query_set->in_progress);
+       if (query_set->in_progress) {
+               ast_log(LOG_ERROR, "Attempted to start asynchronous resolution of query set '%p' when it has already started\n",
+                       query_set);
+               return;
+       }
+
+       query_set->in_progress = 1;
        query_set->callback = callback;
        query_set->user_data = ao2_bump(data);
+
+       /*
+        * Bump the query_set ref in case all queries complete
+        * before we are done kicking them off.
+        */
+       ao2_ref(query_set, +1);
+       for (idx = 0; idx < AST_VECTOR_SIZE(&query_set->queries); ++idx) {
+               struct dns_query_set_query *query = AST_VECTOR_GET_ADDR(&query_set->queries, idx);
+
+               query->query->user_data = ao2_bump(query_set);
+
+               if (!query->query->resolver->resolve(query->query)) {
+                       query->started = 1;
+                       continue;
+               }
+
+               dns_query_set_callback(query->query);
+       }
+       if (!idx) {
+               /*
+                * There were no queries in the set;
+                * therefore all queries are "completed".
+                * Invoke the final callback.
+                */
+               query_set->callback(query_set);
+               ao2_cleanup(query_set->user_data);
+               query_set->user_data = NULL;
+       }
+       ao2_ref(query_set, -1);
 }
 
-void ast_query_set_resolve(struct ast_dns_query_set *query_set)
+/*! \brief Structure used for signaling back for synchronous resolution completion */
+struct dns_synchronous_resolve {
+       /*! \brief Lock used for signaling */
+       ast_mutex_t lock;
+       /*! \brief Condition used for signaling */
+       ast_cond_t cond;
+       /*! \brief Whether the query has completed */
+       unsigned int completed;
+};
+
+/*! \brief Destructor for synchronous resolution structure */
+static void dns_synchronous_resolve_destroy(void *data)
 {
+       struct dns_synchronous_resolve *synchronous = data;
+
+       ast_mutex_destroy(&synchronous->lock);
+       ast_cond_destroy(&synchronous->cond);
 }
 
-int ast_dns_query_set_resolve_cancel(struct ast_dns_query_set *query_set)
+/*! \brief Callback used to implement synchronous resolution */
+static void dns_synchronous_resolve_callback(const struct ast_dns_query_set *query_set)
 {
-       return -1;
+       struct dns_synchronous_resolve *synchronous = ast_dns_query_set_get_data(query_set);
+
+       ast_mutex_lock(&synchronous->lock);
+       synchronous->completed = 1;
+       ast_cond_signal(&synchronous->cond);
+       ast_mutex_unlock(&synchronous->lock);
 }
 
-void ast_dns_query_set_free(struct ast_dns_query_set *query_set)
+int ast_query_set_resolve(struct ast_dns_query_set *query_set)
 {
+       struct dns_synchronous_resolve *synchronous;
+
+       synchronous = ao2_alloc_options(sizeof(*synchronous), dns_synchronous_resolve_destroy, AO2_ALLOC_OPT_LOCK_NOLOCK);
+       if (!synchronous) {
+               return -1;
+       }
+
+       ast_mutex_init(&synchronous->lock);
+       ast_cond_init(&synchronous->cond, NULL);
+
+       ast_dns_query_set_resolve_async(query_set, dns_synchronous_resolve_callback, synchronous);
+
+       /* Wait for resolution to complete */
+       ast_mutex_lock(&synchronous->lock);
+       while (!synchronous->completed) {
+               ast_cond_wait(&synchronous->cond, &synchronous->lock);
+       }
+       ast_mutex_unlock(&synchronous->lock);
+
+       ao2_ref(synchronous, -1);
+
+       return 0;
 }
+
+int ast_dns_query_set_resolve_cancel(struct ast_dns_query_set *query_set)
+{
+       int idx;
+       size_t query_count = AST_VECTOR_SIZE(&query_set->queries);
+
+       for (idx = 0; idx < AST_VECTOR_SIZE(&query_set->queries); ++idx) {
+               struct dns_query_set_query *query = AST_VECTOR_GET_ADDR(&query_set->queries, idx);
+
+               if (query->started) {
+                       if (!query->query->resolver->cancel(query->query)) {
+                               query_set->queries_cancelled++;
+                               dns_query_set_callback(query->query);
+                       }
+               } else {
+                       query_set->queries_cancelled++;
+               }
+       }
+
+       return (query_set->queries_cancelled == query_count) ? 0 : -1;
+}
\ No newline at end of file