Add support for ICE/STUN/TURN in res_rtp_asterisk and chan_sip.
[asterisk/asterisk.git] / res / pjproject / pjlib-util / src / pjlib-util / dns_server.c
1 /* $Id$ */
2 /* 
3  * Copyright (C) 2008-2011 Teluu Inc. (http://www.teluu.com)
4  * Copyright (C) 2003-2008 Benny Prijono <benny@prijono.org>
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation; either version 2 of the License, or
9  * (at your option) any later version.
10  *
11  * This program is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with this program; if not, write to the Free Software
18  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA 
19  */
20 #include <pjlib-util/dns_server.h>
21 #include <pjlib-util/errno.h>
22 #include <pj/activesock.h>
23 #include <pj/assert.h>
24 #include <pj/list.h>
25 #include <pj/log.h>
26 #include <pj/pool.h>
27 #include <pj/string.h>
28
29 #define THIS_FILE   "dns_server.c"
30 #define MAX_ANS     16
31 #define MAX_PKT     1500
32 #define MAX_LABEL   32
33
34 struct label_tab
35 {
36     unsigned count;
37
38     struct {
39         unsigned pos;
40         pj_str_t label;
41     } a[MAX_LABEL];
42 };
43
44 struct rr
45 {
46     PJ_DECL_LIST_MEMBER(struct rr);
47     pj_dns_parsed_rr    rec;
48 };
49
50
51 struct pj_dns_server
52 {
53     pj_pool_t           *pool;
54     pj_pool_factory     *pf;
55     pj_activesock_t     *asock;
56     pj_ioqueue_op_key_t  send_key;
57     struct rr            rr_list;
58 };
59
60
61 static pj_bool_t on_data_recvfrom(pj_activesock_t *asock,
62                                   void *data,
63                                   pj_size_t size,
64                                   const pj_sockaddr_t *src_addr,
65                                   int addr_len,
66                                   pj_status_t status);
67
68
69 PJ_DEF(pj_status_t) pj_dns_server_create( pj_pool_factory *pf,
70                                           pj_ioqueue_t *ioqueue,
71                                           int af,
72                                           unsigned port,
73                                           unsigned flags,
74                                           pj_dns_server **p_srv)
75 {
76     pj_pool_t *pool;
77     pj_dns_server *srv;
78     pj_sockaddr sock_addr;
79     pj_activesock_cb sock_cb;
80     pj_status_t status;
81
82     PJ_ASSERT_RETURN(pf && ioqueue && p_srv && flags==0, PJ_EINVAL);
83     PJ_ASSERT_RETURN(af==pj_AF_INET() || af==pj_AF_INET6(), PJ_EINVAL);
84     
85     pool = pj_pool_create(pf, "dnsserver", 256, 256, NULL);
86     srv = (pj_dns_server*) PJ_POOL_ZALLOC_T(pool, pj_dns_server);
87     srv->pool = pool;
88     srv->pf = pf;
89     pj_list_init(&srv->rr_list);
90
91     pj_bzero(&sock_addr, sizeof(sock_addr));
92     sock_addr.addr.sa_family = (pj_uint16_t)af;
93     pj_sockaddr_set_port(&sock_addr, (pj_uint16_t)port);
94     
95     pj_bzero(&sock_cb, sizeof(sock_cb));
96     sock_cb.on_data_recvfrom = &on_data_recvfrom;
97
98     status = pj_activesock_create_udp(pool, &sock_addr, NULL, ioqueue,
99                                       &sock_cb, srv, &srv->asock, NULL);
100     if (status != PJ_SUCCESS)
101         goto on_error;
102
103     pj_ioqueue_op_key_init(&srv->send_key, sizeof(srv->send_key));
104
105     status = pj_activesock_start_recvfrom(srv->asock, pool, MAX_PKT, 0);
106     if (status != PJ_SUCCESS)
107         goto on_error;
108
109     *p_srv = srv;
110     return PJ_SUCCESS;
111
112 on_error:
113     pj_dns_server_destroy(srv);
114     return status;
115 }
116
117
118 PJ_DEF(pj_status_t) pj_dns_server_destroy(pj_dns_server *srv)
119 {
120     PJ_ASSERT_RETURN(srv, PJ_EINVAL);
121
122     if (srv->asock) {
123         pj_activesock_close(srv->asock);
124         srv->asock = NULL;
125     }
126
127     if (srv->pool) {
128         pj_pool_t *pool = srv->pool;
129         srv->pool = NULL;
130         pj_pool_release(pool);
131     }
132
133     return PJ_SUCCESS;
134 }
135
136
137 static struct rr* find_rr( pj_dns_server *srv,
138                            unsigned dns_class,
139                            unsigned type        /* pj_dns_type */,
140                            const pj_str_t *name)
141 {
142     struct rr *r;
143
144     r = srv->rr_list.next;
145     while (r != &srv->rr_list) {
146         if (r->rec.dnsclass == dns_class && r->rec.type == type && 
147             pj_stricmp(&r->rec.name, name)==0)
148         {
149             return r;
150         }
151         r = r->next;
152     }
153
154     return NULL;
155 }
156
157
158 PJ_DEF(pj_status_t) pj_dns_server_add_rec( pj_dns_server *srv,
159                                            unsigned count,
160                                            const pj_dns_parsed_rr rr_param[])
161 {
162     unsigned i;
163
164     PJ_ASSERT_RETURN(srv && count && rr_param, PJ_EINVAL);
165
166     for (i=0; i<count; ++i) {
167         struct rr *rr;
168
169         PJ_ASSERT_RETURN(find_rr(srv, rr_param[i].dnsclass, rr_param[i].type,
170                                  &rr_param[i].name) == NULL,
171                          PJ_EEXISTS);
172
173         rr = (struct rr*) PJ_POOL_ZALLOC_T(srv->pool, struct rr);
174         pj_memcpy(&rr->rec, &rr_param[i], sizeof(pj_dns_parsed_rr));
175
176         pj_list_push_back(&srv->rr_list, rr);
177     }
178
179     return PJ_SUCCESS;
180 }
181
182
183 PJ_DEF(pj_status_t) pj_dns_server_del_rec( pj_dns_server *srv,
184                                            int dns_class,
185                                            pj_dns_type type,
186                                            const pj_str_t *name)
187 {
188     struct rr *rr;
189
190     PJ_ASSERT_RETURN(srv && type && name, PJ_EINVAL);
191
192     rr = find_rr(srv, dns_class, type, name);
193     if (!rr)
194         return PJ_ENOTFOUND;
195
196     pj_list_erase(rr);
197
198     return PJ_SUCCESS;
199 }
200
201
202 static void write16(pj_uint8_t *p, pj_uint16_t val)
203 {
204     p[0] = (pj_uint8_t)(val >> 8);
205     p[1] = (pj_uint8_t)(val & 0xFF);
206 }
207
208 static void write32(pj_uint8_t *p, pj_uint32_t val)
209 {
210     val = pj_htonl(val);
211     pj_memcpy(p, &val, 4);
212 }
213
214 static int print_name(pj_uint8_t *pkt, int size,
215                       pj_uint8_t *pos, const pj_str_t *name,
216                       struct label_tab *tab)
217 {
218     pj_uint8_t *p = pos;
219     const char *endlabel, *endname;
220     unsigned i;
221     pj_str_t label;
222
223     /* Check if name is in the table */
224     for (i=0; i<tab->count; ++i) {
225         if (pj_strcmp(&tab->a[i].label, name)==0)
226             break;
227     }
228
229     if (i != tab->count) {
230         write16(p, (pj_uint16_t)(tab->a[i].pos | (0xc0 << 8)));
231         return 2;
232     } else {
233         if (tab->count < MAX_LABEL) {
234             tab->a[tab->count].pos = (p-pkt);
235             tab->a[tab->count].label.ptr = (char*)(p+1);
236             tab->a[tab->count].label.slen = name->slen;
237             ++tab->count;
238         }
239     }
240
241     endlabel = name->ptr;
242     endname = name->ptr + name->slen;
243
244     label.ptr = (char*)name->ptr;
245
246     while (endlabel != endname) {
247
248         while (endlabel != endname && *endlabel != '.')
249             ++endlabel;
250
251         label.slen = (endlabel - label.ptr);
252
253         if (size < label.slen+1)
254             return -1;
255
256         *p = (pj_uint8_t)label.slen;
257         pj_memcpy(p+1, label.ptr, label.slen);
258
259         size -= (label.slen+1);
260         p += (label.slen+1);
261
262         if (endlabel != endname && *endlabel == '.')
263             ++endlabel;
264         label.ptr = (char*)endlabel;
265     }
266
267     if (size == 0)
268         return -1;
269
270     *p++ = '\0';
271
272     return p-pos;
273 }
274
275 static int print_rr(pj_uint8_t *pkt, int size, pj_uint8_t *pos,
276                     const pj_dns_parsed_rr *rr, struct label_tab *tab)
277 {
278     pj_uint8_t *p = pos;
279     int len;
280
281     len = print_name(pkt, size, pos, &rr->name, tab);
282     if (len < 0)
283         return -1;
284
285     p += len;
286     size -= len;
287
288     if (size < 8)
289         return -1;
290
291     pj_assert(rr->dnsclass == 1);
292
293     write16(p+0, (pj_uint16_t)rr->type);        /* type     */
294     write16(p+2, (pj_uint16_t)rr->dnsclass);    /* class    */
295     write32(p+4, rr->ttl);                      /* TTL      */
296
297     p += 8;
298     size -= 8;
299
300     if (rr->type == PJ_DNS_TYPE_A) {
301
302         if (size < 6)
303             return -1;
304
305         /* RDLEN is 4 */
306         write16(p, 4);
307
308         /* Address */
309         pj_memcpy(p+2, &rr->rdata.a.ip_addr, 4);
310
311         p += 6;
312         size -= 6;
313
314     } else if (rr->type == PJ_DNS_TYPE_CNAME ||
315                rr->type == PJ_DNS_TYPE_NS ||
316                rr->type == PJ_DNS_TYPE_PTR) {
317
318         if (size < 4)
319             return -1;
320
321         len = print_name(pkt, size-2, p+2, &rr->rdata.cname.name, tab);
322         if (len < 0)
323             return -1;
324
325         write16(p, (pj_uint16_t)len);
326
327         p += (len + 2);
328         size -= (len + 2);
329
330     } else if (rr->type == PJ_DNS_TYPE_SRV) {
331
332         if (size < 10)
333             return -1;
334
335         write16(p+2, rr->rdata.srv.prio);   /* Priority */
336         write16(p+4, rr->rdata.srv.weight); /* Weight */
337         write16(p+6, rr->rdata.srv.port);   /* Port */
338
339         /* Target */
340         len = print_name(pkt, size-8, p+8, &rr->rdata.srv.target, tab);
341         if (len < 0)
342             return -1;
343
344         /* RDLEN */
345         write16(p, (pj_uint16_t)(len + 6));
346
347         p += (len + 8);
348         size -= (len + 8);
349
350     } else {
351         pj_assert(!"Not supported");
352         return -1;
353     }
354
355     return p-pos;
356 }
357
358 static int print_packet(const pj_dns_parsed_packet *rec, pj_uint8_t *pkt,
359                         int size)
360 {
361     pj_uint8_t *p = pkt;
362     struct label_tab tab;
363     int i, len;
364
365     tab.count = 0;
366
367     pj_assert(sizeof(pj_dns_hdr)==12);
368     if (size < (int)sizeof(pj_dns_hdr))
369         return -1;
370
371     /* Initialize header */
372     write16(p+0,  rec->hdr.id);
373     write16(p+2,  rec->hdr.flags);
374     write16(p+4,  rec->hdr.qdcount);
375     write16(p+6,  rec->hdr.anscount);
376     write16(p+8,  rec->hdr.nscount);
377     write16(p+10, rec->hdr.arcount);
378
379     p = pkt + sizeof(pj_dns_hdr);
380     size -= sizeof(pj_dns_hdr);
381
382     /* Print queries */
383     for (i=0; i<rec->hdr.qdcount; ++i) {
384
385         len = print_name(pkt, size, p, &rec->q[i].name, &tab);
386         if (len < 0)
387             return -1;
388
389         p += len;
390         size -= len;
391
392         if (size < 4)
393             return -1;
394
395         /* Set type */
396         write16(p+0, (pj_uint16_t)rec->q[i].type);
397
398         /* Set class (IN=1) */
399         pj_assert(rec->q[i].dnsclass == 1);
400         write16(p+2, rec->q[i].dnsclass);
401
402         p += 4;
403     }
404
405     /* Print answers */
406     for (i=0; i<rec->hdr.anscount; ++i) {
407         len = print_rr(pkt, size, p, &rec->ans[i], &tab);
408         if (len < 0)
409             return -1;
410
411         p += len;
412         size -= len;
413     }
414
415     /* Print NS records */
416     for (i=0; i<rec->hdr.nscount; ++i) {
417         len = print_rr(pkt, size, p, &rec->ns[i], &tab);
418         if (len < 0)
419             return -1;
420
421         p += len;
422         size -= len;
423     }
424
425     /* Print additional records */
426     for (i=0; i<rec->hdr.arcount; ++i) {
427         len = print_rr(pkt, size, p, &rec->arr[i], &tab);
428         if (len < 0)
429             return -1;
430
431         p += len;
432         size -= len;
433     }
434
435     return p - pkt;
436 }
437
438
439 static pj_bool_t on_data_recvfrom(pj_activesock_t *asock,
440                                   void *data,
441                                   pj_size_t size,
442                                   const pj_sockaddr_t *src_addr,
443                                   int addr_len,
444                                   pj_status_t status)
445 {
446     pj_dns_server *srv;
447     pj_pool_t *pool;
448     pj_dns_parsed_packet *req;
449     pj_dns_parsed_packet ans;
450     struct rr *rr;
451     pj_ssize_t pkt_len;
452     unsigned i;
453
454     if (status != PJ_SUCCESS)
455         return PJ_TRUE;
456
457     srv = (pj_dns_server*) pj_activesock_get_user_data(asock);
458     pool = pj_pool_create(srv->pf, "dnssrvrx", 512, 256, NULL);
459
460     status = pj_dns_parse_packet(pool, data, size, &req);
461     if (status != PJ_SUCCESS) {
462         char addrinfo[PJ_INET6_ADDRSTRLEN+10];
463         pj_sockaddr_print(src_addr, addrinfo, sizeof(addrinfo), 3);
464         PJ_LOG(4,(THIS_FILE, "Error parsing query from %s", addrinfo));
465         goto on_return;
466     }
467
468     /* Init answer */
469     pj_bzero(&ans, sizeof(ans));
470     ans.hdr.id = req->hdr.id;
471     ans.hdr.qdcount = 1;
472     ans.q = (pj_dns_parsed_query*) PJ_POOL_ALLOC_T(pool, pj_dns_parsed_query);
473     pj_memcpy(ans.q, req->q, sizeof(pj_dns_parsed_query));
474
475     if (req->hdr.qdcount != 1) {
476         ans.hdr.flags = PJ_DNS_SET_RCODE(PJ_DNS_RCODE_FORMERR);
477         goto send_pkt;
478     }
479
480     if (req->q[0].dnsclass != PJ_DNS_CLASS_IN) {
481         ans.hdr.flags = PJ_DNS_SET_RCODE(PJ_DNS_RCODE_NOTIMPL);
482         goto send_pkt;
483     }
484
485     /* Find the record */
486     rr = find_rr(srv, req->q->dnsclass, req->q->type, &req->q->name);
487     if (rr == NULL) {
488         ans.hdr.flags = PJ_DNS_SET_RCODE(PJ_DNS_RCODE_NXDOMAIN);
489         goto send_pkt;
490     }
491
492     /* Init answer record */
493     ans.hdr.anscount = 0;
494     ans.ans = (pj_dns_parsed_rr*)
495               pj_pool_calloc(pool, MAX_ANS, sizeof(pj_dns_parsed_rr));
496
497     /* DNS SRV query needs special treatment since it returns multiple
498      * records
499      */
500     if (req->q->type == PJ_DNS_TYPE_SRV) {
501         struct rr *r;
502
503         r = srv->rr_list.next;
504         while (r != &srv->rr_list) {
505             if (r->rec.dnsclass == req->q->dnsclass && 
506                 r->rec.type == PJ_DNS_TYPE_SRV && 
507                 pj_stricmp(&r->rec.name, &req->q->name)==0 &&
508                 ans.hdr.anscount < MAX_ANS)
509             {
510                 pj_memcpy(&ans.ans[ans.hdr.anscount], &r->rec,
511                           sizeof(pj_dns_parsed_rr));
512                 ++ans.hdr.anscount;
513             }
514             r = r->next;
515         }
516     } else {
517         /* Otherwise just copy directly from the server record */
518         pj_memcpy(&ans.ans[ans.hdr.anscount], &rr->rec,
519                           sizeof(pj_dns_parsed_rr));
520         ++ans.hdr.anscount;
521     }
522
523     /* For each CNAME entry, add A entry */
524     for (i=0; i<ans.hdr.anscount && ans.hdr.anscount < MAX_ANS; ++i) {
525         if (ans.ans[i].type == PJ_DNS_TYPE_CNAME) {
526             struct rr *r;
527
528             r = find_rr(srv, ans.ans[i].dnsclass, PJ_DNS_TYPE_A,
529                         &ans.ans[i].name);
530             pj_memcpy(&ans.ans[ans.hdr.anscount], &r->rec,
531                               sizeof(pj_dns_parsed_rr));
532             ++ans.hdr.anscount;
533         }
534     }
535
536 send_pkt:
537     pkt_len = print_packet(&ans, (pj_uint8_t*)data, MAX_PKT);
538     if (pkt_len < 1) {
539         PJ_LOG(4,(THIS_FILE, "Error: answer too large"));
540         goto on_return;
541     }
542
543     status = pj_activesock_sendto(srv->asock, &srv->send_key, data, &pkt_len,
544                                   0, src_addr, addr_len);
545     if (status != PJ_SUCCESS && status != PJ_EPENDING) {
546         PJ_LOG(4,(THIS_FILE, "Error sending answer, status=%d", status));
547         goto on_return;
548     }
549
550 on_return:
551     pj_pool_release(pool);
552     return PJ_TRUE;
553 }
554