1 /*
   2  * Copyright (c) 2010, Oracle and/or its affiliates. All rights reserved.
   3  */
   4 
   5 /*
   6  * This file contains code imported from the OFED rds source file message.c
   7  * Oracle elects to have and use the contents of message.c under and governed
   8  * by the OpenIB.org BSD license (see below for full license text). However,
   9  * the following notice accompanied the original version of this file:
  10  */
  11 
  12 /*
  13  * Copyright (c) 2006 Oracle.  All rights reserved.
  14  *
  15  * This software is available to you under a choice of one of two
  16  * licenses.  You may choose to be licensed under the terms of the GNU
  17  * General Public License (GPL) Version 2, available from the file
  18  * COPYING in the main directory of this source tree, or the
  19  * OpenIB.org BSD license below:
  20  *
  21  *     Redistribution and use in source and binary forms, with or
  22  *     without modification, are permitted provided that the following
  23  *     conditions are met:
  24  *
  25  *      - Redistributions of source code must retain the above
  26  *        copyright notice, this list of conditions and the following
  27  *        disclaimer.
  28  *
  29  *      - Redistributions in binary form must reproduce the above
  30  *        copyright notice, this list of conditions and the following
  31  *        disclaimer in the documentation and/or other materials
  32  *        provided with the distribution.
  33  *
  34  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
  35  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
  36  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
  37  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
  38  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
  39  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
  40  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  41  * SOFTWARE.
  42  *
  43  */
  44 #include <sys/rds.h>
  45 
  46 #include <sys/ib/clients/rdsv3/rdsv3.h>
  47 #include <sys/ib/clients/rdsv3/rdma.h>
  48 #include <sys/ib/clients/rdsv3/rdsv3_debug.h>
  49 
  50 #ifndef __lock_lint
  51 static unsigned int     rdsv3_exthdr_size[__RDSV3_EXTHDR_MAX] = {
  52 [RDSV3_EXTHDR_NONE]     = 0,
  53 [RDSV3_EXTHDR_VERSION]  = sizeof (struct rdsv3_ext_header_version),
  54 [RDSV3_EXTHDR_RDMA]     = sizeof (struct rdsv3_ext_header_rdma),
  55 [RDSV3_EXTHDR_RDMA_DEST]        = sizeof (struct rdsv3_ext_header_rdma_dest),
  56 };
  57 #else
  58 static unsigned int     rdsv3_exthdr_size[__RDSV3_EXTHDR_MAX] = {
  59                         0,
  60                         sizeof (struct rdsv3_ext_header_version),
  61                         sizeof (struct rdsv3_ext_header_rdma),
  62                         sizeof (struct rdsv3_ext_header_rdma_dest),
  63 };
  64 #endif
  65 
  66 void
  67 rdsv3_message_addref(struct rdsv3_message *rm)
  68 {
  69         RDSV3_DPRINTF5("rdsv3_message_addref", "addref rm %p ref %d",
  70             rm, atomic_get(&rm->m_refcount));
  71         atomic_add_32(&rm->m_refcount, 1);
  72 }
  73 
  74 /*
  75  * This relies on dma_map_sg() not touching sg[].page during merging.
  76  */
  77 static void
  78 rdsv3_message_purge(struct rdsv3_message *rm)
  79 {
  80         unsigned long i;
  81 
  82         RDSV3_DPRINTF4("rdsv3_message_purge", "Enter(rm: %p)", rm);
  83 
  84         if (test_bit(RDSV3_MSG_PAGEVEC, &rm->m_flags))
  85                 return;
  86 
  87         for (i = 0; i < rm->m_nents; i++) {
  88                 RDSV3_DPRINTF5("rdsv3_message_purge", "putting data page %p\n",
  89                     (void *)rdsv3_sg_page(&rm->m_sg[i]));
  90                 /* XXX will have to put_page for page refs */
  91                 kmem_free(rdsv3_sg_page(&rm->m_sg[i]),
  92                     rdsv3_sg_len(&rm->m_sg[i]));
  93         }
  94 
  95         if (rm->m_rdma_op)
  96                 rdsv3_rdma_free_op(rm->m_rdma_op);
  97         if (rm->m_rdma_mr) {
  98                 struct rdsv3_mr *mr = rm->m_rdma_mr;
  99                 if (mr->r_refcount == 0) {
 100                         RDSV3_DPRINTF4("rdsv3_message_purge ASSERT 0",
 101                             "rm %p mr %p", rm, mr);
 102                         return;
 103                 }
 104                 if (mr->r_refcount == 0xdeadbeef) {
 105                         RDSV3_DPRINTF4("rdsv3_message_purge ASSERT deadbeef",
 106                             "rm %p mr %p", rm, mr);
 107                         return;
 108                 }
 109                 if (atomic_dec_and_test(&mr->r_refcount)) {
 110                         rm->m_rdma_mr = NULL;
 111                         __rdsv3_put_mr_final(mr);
 112                 }
 113         }
 114 
 115         RDSV3_DPRINTF4("rdsv3_message_purge", "Return(rm: %p)", rm);
 116 
 117 }
 118 
 119 void
 120 rdsv3_message_put(struct rdsv3_message *rm)
 121 {
 122         RDSV3_DPRINTF5("rdsv3_message_put",
 123             "put rm %p ref %d\n", rm, atomic_get(&rm->m_refcount));
 124 
 125         if (atomic_dec_and_test(&rm->m_refcount)) {
 126                 ASSERT(!list_link_active(&rm->m_sock_item));
 127                 ASSERT(!list_link_active(&rm->m_conn_item));
 128                 rdsv3_message_purge(rm);
 129 
 130                 kmem_free(rm, sizeof (struct rdsv3_message) +
 131                     (rm->m_nents * sizeof (struct rdsv3_scatterlist)));
 132         }
 133 }
 134 
 135 void
 136 rdsv3_message_inc_free(struct rdsv3_incoming *inc)
 137 {
 138         struct rdsv3_message *rm =
 139             container_of(inc, struct rdsv3_message, m_inc);
 140         rdsv3_message_put(rm);
 141 }
 142 
 143 void
 144 rdsv3_message_populate_header(struct rdsv3_header *hdr, uint16_be_t sport,
 145     uint16_be_t dport, uint64_t seq)
 146 {
 147         hdr->h_flags = 0;
 148         hdr->h_sport = sport;
 149         hdr->h_dport = dport;
 150         hdr->h_sequence = htonll(seq);
 151         hdr->h_exthdr[0] = RDSV3_EXTHDR_NONE;
 152 }
 153 
 154 int
 155 rdsv3_message_add_extension(struct rdsv3_header *hdr,
 156     unsigned int type, const void *data, unsigned int len)
 157 {
 158         unsigned int ext_len = sizeof (uint8_t) + len;
 159         unsigned char *dst;
 160 
 161         RDSV3_DPRINTF4("rdsv3_message_add_extension", "Enter");
 162 
 163         /* For now, refuse to add more than one extension header */
 164         if (hdr->h_exthdr[0] != RDSV3_EXTHDR_NONE)
 165                 return (0);
 166 
 167         if (type >= __RDSV3_EXTHDR_MAX ||
 168             len != rdsv3_exthdr_size[type])
 169                 return (0);
 170 
 171         if (ext_len >= RDSV3_HEADER_EXT_SPACE)
 172                 return (0);
 173         dst = hdr->h_exthdr;
 174 
 175         *dst++ = type;
 176         (void) memcpy(dst, data, len);
 177 
 178         dst[len] = RDSV3_EXTHDR_NONE;
 179 
 180         RDSV3_DPRINTF4("rdsv3_message_add_extension", "Return");
 181         return (1);
 182 }
 183 
 184 /*
 185  * If a message has extension headers, retrieve them here.
 186  * Call like this:
 187  *
 188  * unsigned int pos = 0;
 189  *
 190  * while (1) {
 191  *      buflen = sizeof(buffer);
 192  *      type = rdsv3_message_next_extension(hdr, &pos, buffer, &buflen);
 193  *      if (type == RDSV3_EXTHDR_NONE)
 194  *              break;
 195  *      ...
 196  * }
 197  */
 198 int
 199 rdsv3_message_next_extension(struct rdsv3_header *hdr,
 200     unsigned int *pos, void *buf, unsigned int *buflen)
 201 {
 202         unsigned int offset, ext_type, ext_len;
 203         uint8_t *src = hdr->h_exthdr;
 204 
 205         RDSV3_DPRINTF4("rdsv3_message_next_extension", "Enter");
 206 
 207         offset = *pos;
 208         if (offset >= RDSV3_HEADER_EXT_SPACE)
 209                 goto none;
 210 
 211         /*
 212          * Get the extension type and length. For now, the
 213          * length is implied by the extension type.
 214          */
 215         ext_type = src[offset++];
 216 
 217         if (ext_type == RDSV3_EXTHDR_NONE || ext_type >= __RDSV3_EXTHDR_MAX)
 218                 goto none;
 219         ext_len = rdsv3_exthdr_size[ext_type];
 220         if (offset + ext_len > RDSV3_HEADER_EXT_SPACE)
 221                 goto none;
 222 
 223         *pos = offset + ext_len;
 224         if (ext_len < *buflen)
 225                 *buflen = ext_len;
 226         (void) memcpy(buf, src + offset, *buflen);
 227         return (ext_type);
 228 
 229 none:
 230         *pos = RDSV3_HEADER_EXT_SPACE;
 231         *buflen = 0;
 232         return (RDSV3_EXTHDR_NONE);
 233 }
 234 
 235 int
 236 rdsv3_message_add_version_extension(struct rdsv3_header *hdr,
 237     unsigned int version)
 238 {
 239         struct rdsv3_ext_header_version ext_hdr;
 240 
 241         ext_hdr.h_version = htonl(version);
 242         return (rdsv3_message_add_extension(hdr, RDSV3_EXTHDR_VERSION,
 243             &ext_hdr, sizeof (ext_hdr)));
 244 }
 245 
 246 int
 247 rdsv3_message_get_version_extension(struct rdsv3_header *hdr,
 248     unsigned int *version)
 249 {
 250         struct rdsv3_ext_header_version ext_hdr;
 251         unsigned int pos = 0, len = sizeof (ext_hdr);
 252 
 253         RDSV3_DPRINTF4("rdsv3_message_get_version_extension", "Enter");
 254 
 255         /*
 256          * We assume the version extension is the only one present
 257          */
 258         if (rdsv3_message_next_extension(hdr, &pos, &ext_hdr, &len) !=
 259             RDSV3_EXTHDR_VERSION)
 260                 return (0);
 261         *version = ntohl(ext_hdr.h_version);
 262         return (1);
 263 }
 264 
 265 int
 266 rdsv3_message_add_rdma_dest_extension(struct rdsv3_header *hdr, uint32_t r_key,
 267     uint32_t offset)
 268 {
 269         struct rdsv3_ext_header_rdma_dest ext_hdr;
 270 
 271         ext_hdr.h_rdma_rkey = htonl(r_key);
 272         ext_hdr.h_rdma_offset = htonl(offset);
 273         return (rdsv3_message_add_extension(hdr, RDSV3_EXTHDR_RDMA_DEST,
 274             &ext_hdr, sizeof (ext_hdr)));
 275 }
 276 
 277 struct rdsv3_message *
 278 rdsv3_message_alloc(unsigned int nents, int gfp)
 279 {
 280         struct rdsv3_message *rm;
 281 
 282         RDSV3_DPRINTF4("rdsv3_message_alloc", "Enter(nents: %d)", nents);
 283 
 284         rm = kmem_zalloc(sizeof (struct rdsv3_message) +
 285             (nents * sizeof (struct rdsv3_scatterlist)), gfp);
 286         if (!rm)
 287                 goto out;
 288 
 289         rm->m_refcount = 1;
 290         list_link_init(&rm->m_sock_item);
 291         list_link_init(&rm->m_conn_item);
 292         mutex_init(&rm->m_rs_lock, NULL, MUTEX_DRIVER, NULL);
 293         rdsv3_init_waitqueue(&rm->m_flush_wait);
 294 
 295         RDSV3_DPRINTF4("rdsv3_message_alloc", "Return(rm: %p)", rm);
 296 out:
 297         return (rm);
 298 }
 299 
 300 struct rdsv3_message *
 301 rdsv3_message_map_pages(unsigned long *page_addrs, unsigned int total_len)
 302 {
 303         struct rdsv3_message *rm;
 304         unsigned int i;
 305 
 306         RDSV3_DPRINTF4("rdsv3_message_map_pages", "Enter(len: %d)", total_len);
 307 
 308 #ifndef __lock_lint
 309         rm = rdsv3_message_alloc(ceil(total_len, PAGE_SIZE), KM_NOSLEEP);
 310 #else
 311         rm = NULL;
 312 #endif
 313         if (rm == NULL)
 314                 return (ERR_PTR(-ENOMEM));
 315 
 316         set_bit(RDSV3_MSG_PAGEVEC, &rm->m_flags);
 317         rm->m_inc.i_hdr.h_len = htonl(total_len);
 318 #ifndef __lock_lint
 319         rm->m_nents = ceil(total_len, PAGE_SIZE);
 320 #else
 321         rm->m_nents = 0;
 322 #endif
 323 
 324         for (i = 0; i < rm->m_nents; ++i) {
 325                 rdsv3_sg_set_page(&rm->m_sg[i],
 326                     page_addrs[i],
 327                     PAGE_SIZE, 0);
 328         }
 329 
 330         return (rm);
 331 }
 332 
 333 struct rdsv3_message *
 334 rdsv3_message_copy_from_user(struct uio *uiop,
 335     size_t total_len)
 336 {
 337         struct rdsv3_message *rm;
 338         struct rdsv3_scatterlist *sg;
 339         int ret;
 340 
 341         RDSV3_DPRINTF4("rdsv3_message_copy_from_user", "Enter: %d", total_len);
 342 
 343 #ifndef __lock_lint
 344         rm = rdsv3_message_alloc(ceil(total_len, PAGE_SIZE), KM_NOSLEEP);
 345 #else
 346         rm = NULL;
 347 #endif
 348         if (rm == NULL) {
 349                 ret = -ENOMEM;
 350                 goto out;
 351         }
 352 
 353         rm->m_inc.i_hdr.h_len = htonl(total_len);
 354 
 355         /*
 356          * now allocate and copy in the data payload.
 357          */
 358         sg = rm->m_sg;
 359 
 360         while (total_len) {
 361                 if (rdsv3_sg_page(sg) == NULL) {
 362                         ret = rdsv3_page_remainder_alloc(sg, total_len, 0);
 363                         if (ret)
 364                                 goto out;
 365                         rm->m_nents++;
 366                 }
 367 
 368                 ret = uiomove(rdsv3_sg_page(sg), rdsv3_sg_len(sg), UIO_WRITE,
 369                     uiop);
 370                 if (ret) {
 371                         RDSV3_DPRINTF2("rdsv3_message_copy_from_user",
 372                             "uiomove failed");
 373                         ret = -ret;
 374                         goto out;
 375                 }
 376 
 377                 total_len -= rdsv3_sg_len(sg);
 378                 sg++;
 379         }
 380         ret = 0;
 381 out:
 382         if (ret) {
 383                 if (rm)
 384                         rdsv3_message_put(rm);
 385                 rm = ERR_PTR(ret);
 386         }
 387         return (rm);
 388 }
 389 
 390 int
 391 rdsv3_message_inc_copy_to_user(struct rdsv3_incoming *inc,
 392     uio_t *uiop, size_t size)
 393 {
 394         struct rdsv3_message *rm;
 395         struct rdsv3_scatterlist *sg;
 396         unsigned long to_copy;
 397         unsigned long vec_off;
 398         int copied;
 399         int ret;
 400         uint32_t len;
 401 
 402         rm = container_of(inc, struct rdsv3_message, m_inc);
 403         len = ntohl(rm->m_inc.i_hdr.h_len);
 404 
 405         RDSV3_DPRINTF4("rdsv3_message_inc_copy_to_user",
 406             "Enter(rm: %p, len: %d)", rm, len);
 407 
 408         sg = rm->m_sg;
 409         vec_off = 0;
 410         copied = 0;
 411 
 412         while (copied < size && copied < len) {
 413 
 414                 to_copy = min(len - copied, sg->length - vec_off);
 415                 to_copy = min(size - copied, to_copy);
 416 
 417                 RDSV3_DPRINTF5("rdsv3_message_inc_copy_to_user",
 418                     "copying %lu bytes to user iov %p from sg [%p, %u] + %lu\n",
 419                     to_copy, uiop,
 420                     rdsv3_sg_page(sg), sg->length, vec_off);
 421 
 422                 ret = uiomove(rdsv3_sg_page(sg), to_copy, UIO_READ, uiop);
 423                 if (ret)
 424                         break;
 425 
 426                 vec_off += to_copy;
 427                 copied += to_copy;
 428 
 429                 if (vec_off == sg->length) {
 430                         vec_off = 0;
 431                         sg++;
 432                 }
 433         }
 434 
 435         return (copied);
 436 }
 437 
 438 /*
 439  * If the message is still on the send queue, wait until the transport
 440  * is done with it. This is particularly important for RDMA operations.
 441  */
 442 /* ARGSUSED */
 443 void
 444 rdsv3_message_wait(struct rdsv3_message *rm)
 445 {
 446         rdsv3_wait_event(&rm->m_flush_wait,
 447             !test_bit(RDSV3_MSG_MAPPED, &rm->m_flags));
 448 }
 449 
 450 void
 451 rdsv3_message_unmapped(struct rdsv3_message *rm)
 452 {
 453         clear_bit(RDSV3_MSG_MAPPED, &rm->m_flags);
 454         rdsv3_wake_up_all(&rm->m_flush_wait);
 455 }