about summary refs log tree commit diff stats
path: root/drivers/vhost
diff options
context:
space:
mode:
author Linus Torvalds <torvalds@linux-foundation.org> 2023-07-03 15:38:26 -0700
committer Linus Torvalds <torvalds@linux-foundation.org> 2023-07-03 15:38:26 -0700
commit a8d70602b186f3c347e62c59a418be802b71886d (patch)
tree 48bf9b05703ff824a4dddfaaa773687c9fe6fd05 /drivers/vhost
parent e8069f5a8e3bdb5fdeeff895780529388592ee7a (diff)
parent 9e396a2f434f829fb3b98a24bb8db5429320589d (diff)
download linux-a8d70602b186f3c347e62c59a418be802b71886d.tar.gz
Merge tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost
Pull virtio updates from Michael Tsirkin:

 - resume support in vdpa/solidrun
 - structure size optimizations in virtio_pci
 - new pds_vdpa driver
 - immediate initialization mechanism for vdpa/ifcvf
 - interrupt bypass for vdpa/mlx5
 - multiple worker support for vhost
 - virtio net in Intel F2000X-PL support for vdpa/ifcvf
 - fixes, cleanups all over the place

* tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost: (48 commits)
  vhost: Make parameter name match of vhost_get_vq_desc()
  vduse: fix NULL pointer dereference
  vhost: Allow worker switching while work is queueing
  vhost_scsi: add support for worker ioctls
  vhost: allow userspace to create workers
  vhost: replace single worker pointer with xarray
  vhost: add helper to parse userspace vring state/file
  vhost: remove vhost_work_queue
  vhost_scsi: flush IO vqs then send TMF rsp
  vhost_scsi: convert to vhost_vq_work_queue
  vhost_scsi: make SCSI cmd completion per vq
  vhost_sock: convert to vhost_vq_work_queue
  vhost: convert poll work to be vq based
  vhost: take worker or vq for flushing
  vhost: take worker or vq instead of dev for queueing
  vhost, vhost_net: add helper to check if vq has work
  vhost: add vhost_worker pointer to vhost_virtqueue
  vhost: dynamically allocate vhost_worker
  vhost: create worker at end of vhost_dev_set_owner
  virtio_bt: call scheduler when we free unused buffs
  ...
Diffstat (limited to 'drivers/vhost')
-rw-r--r-- drivers/vhost/net.c 8
-rw-r--r-- drivers/vhost/scsi.c 103
-rw-r--r-- drivers/vhost/vhost.c 419
-rw-r--r-- drivers/vhost/vhost.h 24
-rw-r--r-- drivers/vhost/vsock.c 4
5 files changed, 449 insertions, 109 deletions
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index ae2273196b0c90..f2ed7167c84809 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -546,7 +546,7 @@ static void vhost_net_busy_poll(struct vhost_net *net,
endtime = busy_clock() + busyloop_timeout;
while (vhost_can_busy_poll(endtime)) {
- if (vhost_has_work(&net->dev)) {
+ if (vhost_vq_has_work(vq)) {
*busyloop_intr = true;
break;
}
@@ -1347,8 +1347,10 @@ static int vhost_net_open(struct inode *inode, struct file *f)
VHOST_NET_PKT_WEIGHT, VHOST_NET_WEIGHT, true,
NULL);
- vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, EPOLLOUT, dev);
- vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, EPOLLIN, dev);
+ vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, EPOLLOUT, dev,
+ vqs[VHOST_NET_VQ_TX]);
+ vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, EPOLLIN, dev,
+ vqs[VHOST_NET_VQ_RX]);
f->private_data = n;
n->page_frag.page = NULL;
diff --git a/drivers/vhost/scsi.c b/drivers/vhost/scsi.c
index bb10fa4bb4f6ec..c83f7f043470d6 100644
--- a/drivers/vhost/scsi.c
+++ b/drivers/vhost/scsi.c
@@ -167,6 +167,7 @@ MODULE_PARM_DESC(max_io_vqs, "Set the max number of IO virtqueues a vhost scsi d
struct vhost_scsi_virtqueue {
struct vhost_virtqueue vq;
+ struct vhost_scsi *vs;
/*
* Reference counting for inflight reqs, used for flush operation. At
* each time, one reference tracks new commands submitted, while we
@@ -181,6 +182,9 @@ struct vhost_scsi_virtqueue {
struct vhost_scsi_cmd *scsi_cmds;
struct sbitmap scsi_tags;
int max_cmds;
+
+ struct vhost_work completion_work;
+ struct llist_head completion_list;
};
struct vhost_scsi {
@@ -190,12 +194,8 @@ struct vhost_scsi {
struct vhost_dev dev;
struct vhost_scsi_virtqueue *vqs;
- unsigned long *compl_bitmap;
struct vhost_scsi_inflight **old_inflight;
- struct vhost_work vs_completion_work; /* cmd completion work item */
- struct llist_head vs_completion_list; /* cmd completion queue */
-
struct vhost_work vs_event_work; /* evt injection work item */
struct llist_head vs_event_list; /* evt injection queue */
@@ -353,15 +353,17 @@ static void vhost_scsi_release_cmd(struct se_cmd *se_cmd)
if (se_cmd->se_cmd_flags & SCF_SCSI_TMR_CDB) {
struct vhost_scsi_tmf *tmf = container_of(se_cmd,
struct vhost_scsi_tmf, se_cmd);
+ struct vhost_virtqueue *vq = &tmf->svq->vq;
- vhost_work_queue(&tmf->vhost->dev, &tmf->vwork);
+ vhost_vq_work_queue(vq, &tmf->vwork);
} else {
struct vhost_scsi_cmd *cmd = container_of(se_cmd,
struct vhost_scsi_cmd, tvc_se_cmd);
- struct vhost_scsi *vs = cmd->tvc_vhost;
+ struct vhost_scsi_virtqueue *svq = container_of(cmd->tvc_vq,
+ struct vhost_scsi_virtqueue, vq);
- llist_add(&cmd->tvc_completion_list, &vs->vs_completion_list);
- vhost_work_queue(&vs->dev, &vs->vs_completion_work);
+ llist_add(&cmd->tvc_completion_list, &svq->completion_list);
+ vhost_vq_work_queue(&svq->vq, &svq->completion_work);
}
}
@@ -509,17 +511,17 @@ static void vhost_scsi_evt_work(struct vhost_work *work)
*/
static void vhost_scsi_complete_cmd_work(struct vhost_work *work)
{
- struct vhost_scsi *vs = container_of(work, struct vhost_scsi,
- vs_completion_work);
+ struct vhost_scsi_virtqueue *svq = container_of(work,
+ struct vhost_scsi_virtqueue, completion_work);
struct virtio_scsi_cmd_resp v_rsp;
struct vhost_scsi_cmd *cmd, *t;
struct llist_node *llnode;
struct se_cmd *se_cmd;
struct iov_iter iov_iter;
- int ret, vq;
+ bool signal = false;
+ int ret;
- bitmap_zero(vs->compl_bitmap, vs->dev.nvqs);
- llnode = llist_del_all(&vs->vs_completion_list);
+ llnode = llist_del_all(&svq->completion_list);
llist_for_each_entry_safe(cmd, t, llnode, tvc_completion_list) {
se_cmd = &cmd->tvc_se_cmd;
@@ -539,21 +541,17 @@ static void vhost_scsi_complete_cmd_work(struct vhost_work *work)
cmd->tvc_in_iovs, sizeof(v_rsp));
ret = copy_to_iter(&v_rsp, sizeof(v_rsp), &iov_iter);
if (likely(ret == sizeof(v_rsp))) {
- struct vhost_scsi_virtqueue *q;
+ signal = true;
+
vhost_add_used(cmd->tvc_vq, cmd->tvc_vq_desc, 0);
- q = container_of(cmd->tvc_vq, struct vhost_scsi_virtqueue, vq);
- vq = q - vs->vqs;
- __set_bit(vq, vs->compl_bitmap);
} else
pr_err("Faulted on virtio_scsi_cmd_resp\n");
vhost_scsi_release_cmd_res(se_cmd);
}
- vq = -1;
- while ((vq = find_next_bit(vs->compl_bitmap, vs->dev.nvqs, vq + 1))
- < vs->dev.nvqs)
- vhost_signal(&vs->dev, &vs->vqs[vq].vq);
+ if (signal)
+ vhost_signal(&svq->vs->dev, &svq->vq);
}
static struct vhost_scsi_cmd *
@@ -1135,12 +1133,27 @@ static void vhost_scsi_tmf_resp_work(struct vhost_work *work)
{
struct vhost_scsi_tmf *tmf = container_of(work, struct vhost_scsi_tmf,
vwork);
- int resp_code;
+ struct vhost_virtqueue *ctl_vq, *vq;
+ int resp_code, i;
+
+ if (tmf->scsi_resp == TMR_FUNCTION_COMPLETE) {
+ /*
+ * Flush IO vqs that don't share a worker with the ctl to make
+ * sure they have sent their responses before us.
+ */
+ ctl_vq = &tmf->vhost->vqs[VHOST_SCSI_VQ_CTL].vq;
+ for (i = VHOST_SCSI_VQ_IO; i < tmf->vhost->dev.nvqs; i++) {
+ vq = &tmf->vhost->vqs[i].vq;
+
+ if (vhost_vq_is_setup(vq) &&
+ vq->worker != ctl_vq->worker)
+ vhost_vq_flush(vq);
+ }
- if (tmf->scsi_resp == TMR_FUNCTION_COMPLETE)
resp_code = VIRTIO_SCSI_S_FUNCTION_SUCCEEDED;
- else
+ } else {
resp_code = VIRTIO_SCSI_S_FUNCTION_REJECTED;
+ }
vhost_scsi_send_tmf_resp(tmf->vhost, &tmf->svq->vq, tmf->in_iovs,
tmf->vq_desc, &tmf->resp_iov, resp_code);
@@ -1335,11 +1348,9 @@ static void vhost_scsi_ctl_handle_kick(struct vhost_work *work)
}
static void
-vhost_scsi_send_evt(struct vhost_scsi *vs,
- struct vhost_scsi_tpg *tpg,
- struct se_lun *lun,
- u32 event,
- u32 reason)
+vhost_scsi_send_evt(struct vhost_scsi *vs, struct vhost_virtqueue *vq,
+ struct vhost_scsi_tpg *tpg, struct se_lun *lun,
+ u32 event, u32 reason)
{
struct vhost_scsi_evt *evt;
@@ -1361,7 +1372,7 @@ vhost_scsi_send_evt(struct vhost_scsi *vs,
}
llist_add(&evt->list, &vs->vs_event_list);
- vhost_work_queue(&vs->dev, &vs->vs_event_work);
+ vhost_vq_work_queue(vq, &vs->vs_event_work);
}
static void vhost_scsi_evt_handle_kick(struct vhost_work *work)
@@ -1375,7 +1386,8 @@ static void vhost_scsi_evt_handle_kick(struct vhost_work *work)
goto out;
if (vs->vs_events_missed)
- vhost_scsi_send_evt(vs, NULL, NULL, VIRTIO_SCSI_T_NO_EVENT, 0);
+ vhost_scsi_send_evt(vs, vq, NULL, NULL, VIRTIO_SCSI_T_NO_EVENT,
+ 0);
out:
mutex_unlock(&vq->mutex);
}
@@ -1770,6 +1782,7 @@ static int vhost_scsi_set_features(struct vhost_scsi *vs, u64 features)
static int vhost_scsi_open(struct inode *inode, struct file *f)
{
+ struct vhost_scsi_virtqueue *svq;
struct vhost_scsi *vs;
struct vhost_virtqueue **vqs;
int r = -ENOMEM, i, nvqs = vhost_scsi_max_io_vqs;
@@ -1788,10 +1801,6 @@ static int vhost_scsi_open(struct inode *inode, struct file *f)
}
nvqs += VHOST_SCSI_VQ_IO;
- vs->compl_bitmap = bitmap_alloc(nvqs, GFP_KERNEL);
- if (!vs->compl_bitmap)
- goto err_compl_bitmap;
-
vs->old_inflight = kmalloc_array(nvqs, sizeof(*vs->old_inflight),
GFP_KERNEL | __GFP_ZERO);
if (!vs->old_inflight)
@@ -1806,7 +1815,6 @@ static int vhost_scsi_open(struct inode *inode, struct file *f)
if (!vqs)
goto err_local_vqs;
- vhost_work_init(&vs->vs_completion_work, vhost_scsi_complete_cmd_work);
vhost_work_init(&vs->vs_event_work, vhost_scsi_evt_work);
vs->vs_events_nr = 0;
@@ -1817,8 +1825,14 @@ static int vhost_scsi_open(struct inode *inode, struct file *f)
vs->vqs[VHOST_SCSI_VQ_CTL].vq.handle_kick = vhost_scsi_ctl_handle_kick;
vs->vqs[VHOST_SCSI_VQ_EVT].vq.handle_kick = vhost_scsi_evt_handle_kick;
for (i = VHOST_SCSI_VQ_IO; i < nvqs; i++) {
- vqs[i] = &vs->vqs[i].vq;
- vs->vqs[i].vq.handle_kick = vhost_scsi_handle_kick;
+ svq = &vs->vqs[i];
+
+ vqs[i] = &svq->vq;
+ svq->vs = vs;
+ init_llist_head(&svq->completion_list);
+ vhost_work_init(&svq->completion_work,
+ vhost_scsi_complete_cmd_work);
+ svq->vq.handle_kick = vhost_scsi_handle_kick;
}
vhost_dev_init(&vs->dev, vqs, nvqs, UIO_MAXIOV,
VHOST_SCSI_WEIGHT, 0, true, NULL);
@@ -1833,8 +1847,6 @@ err_local_vqs:
err_vqs:
kfree(vs->old_inflight);
err_inflight:
- bitmap_free(vs->compl_bitmap);
-err_compl_bitmap:
kvfree(vs);
err_vs:
return r;
@@ -1854,7 +1866,6 @@ static int vhost_scsi_release(struct inode *inode, struct file *f)
kfree(vs->dev.vqs);
kfree(vs->vqs);
kfree(vs->old_inflight);
- bitmap_free(vs->compl_bitmap);
kvfree(vs);
return 0;
}
@@ -1916,6 +1927,14 @@ vhost_scsi_ioctl(struct file *f,
if (copy_from_user(&features, featurep, sizeof features))
return -EFAULT;
return vhost_scsi_set_features(vs, features);
+ case VHOST_NEW_WORKER:
+ case VHOST_FREE_WORKER:
+ case VHOST_ATTACH_VRING_WORKER:
+ case VHOST_GET_VRING_WORKER:
+ mutex_lock(&vs->dev.mutex);
+ r = vhost_worker_ioctl(&vs->dev, ioctl, argp);
+ mutex_unlock(&vs->dev.mutex);
+ return r;
default:
mutex_lock(&vs->dev.mutex);
r = vhost_dev_ioctl(&vs->dev, ioctl, argp);
@@ -1995,7 +2014,7 @@ vhost_scsi_do_plug(struct vhost_scsi_tpg *tpg,
goto unlock;
if (vhost_has_feature(vq, VIRTIO_SCSI_F_HOTPLUG))
- vhost_scsi_send_evt(vs, tpg, lun,
+ vhost_scsi_send_evt(vs, vq, tpg, lun,
VIRTIO_SCSI_T_TRANSPORT_RESET, reason);
unlock:
mutex_unlock(&vq->mutex);
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 60c9ebd629dd15..c71d573f1c9497 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -187,13 +187,15 @@ EXPORT_SYMBOL_GPL(vhost_work_init);
/* Init poll structure */
void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn,
- __poll_t mask, struct vhost_dev *dev)
+ __poll_t mask, struct vhost_dev *dev,
+ struct vhost_virtqueue *vq)
{
init_waitqueue_func_entry(&poll->wait, vhost_poll_wakeup);
init_poll_funcptr(&poll->table, vhost_poll_func);
poll->mask = mask;
poll->dev = dev;
poll->wqh = NULL;
+ poll->vq = vq;
vhost_work_init(&poll->work, fn);
}
@@ -231,46 +233,102 @@ void vhost_poll_stop(struct vhost_poll *poll)
}
EXPORT_SYMBOL_GPL(vhost_poll_stop);
-void vhost_dev_flush(struct vhost_dev *dev)
+static void vhost_worker_queue(struct vhost_worker *worker,
+ struct vhost_work *work)
+{
+ if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) {
+ /* We can only add the work to the list after we're
+ * sure it was not in the list.
+ * test_and_set_bit() implies a memory barrier.
+ */
+ llist_add(&work->node, &worker->work_list);
+ vhost_task_wake(worker->vtsk);
+ }
+}
+
+bool vhost_vq_work_queue(struct vhost_virtqueue *vq, struct vhost_work *work)
+{
+ struct vhost_worker *worker;
+ bool queued = false;
+
+ rcu_read_lock();
+ worker = rcu_dereference(vq->worker);
+ if (worker) {
+ queued = true;
+ vhost_worker_queue(worker, work);
+ }
+ rcu_read_unlock();
+
+ return queued;
+}
+EXPORT_SYMBOL_GPL(vhost_vq_work_queue);
+
+void vhost_vq_flush(struct vhost_virtqueue *vq)
{
struct vhost_flush_struct flush;
- if (dev->worker.vtsk) {
- init_completion(&flush.wait_event);
- vhost_work_init(&flush.work, vhost_flush_work);
+ init_completion(&flush.wait_event);
+ vhost_work_init(&flush.work, vhost_flush_work);
- vhost_work_queue(dev, &flush.work);
+ if (vhost_vq_work_queue(vq, &flush.work))
wait_for_completion(&flush.wait_event);
- }
}
-EXPORT_SYMBOL_GPL(vhost_dev_flush);
+EXPORT_SYMBOL_GPL(vhost_vq_flush);
-void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work)
+/**
+ * vhost_worker_flush - flush a worker
+ * @worker: worker to flush
+ *
+ * This does not use RCU to protect the worker, so the device or worker
+ * mutex must be held.
+ */
+static void vhost_worker_flush(struct vhost_worker *worker)
{
- if (!dev->worker.vtsk)
- return;
+ struct vhost_flush_struct flush;
- if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) {
- /* We can only add the work to the list after we're
- * sure it was not in the list.
- * test_and_set_bit() implies a memory barrier.
- */
- llist_add(&work->node, &dev->worker.work_list);
- vhost_task_wake(dev->worker.vtsk);
+ init_completion(&flush.wait_event);
+ vhost_work_init(&flush.work, vhost_flush_work);
+
+ vhost_worker_queue(worker, &flush.work);
+ wait_for_completion(&flush.wait_event);
+}
+
+void vhost_dev_flush(struct vhost_dev *dev)
+{
+ struct vhost_worker *worker;
+ unsigned long i;
+
+ xa_for_each(&dev->worker_xa, i, worker) {
+ mutex_lock(&worker->mutex);
+ if (!worker->attachment_cnt) {
+ mutex_unlock(&worker->mutex);
+ continue;
+ }
+ vhost_worker_flush(worker);
+ mutex_unlock(&worker->mutex);
}
}
-EXPORT_SYMBOL_GPL(vhost_work_queue);
+EXPORT_SYMBOL_GPL(vhost_dev_flush);
/* A lockless hint for busy polling code to exit the loop */
-bool vhost_has_work(struct vhost_dev *dev)
+bool vhost_vq_has_work(struct vhost_virtqueue *vq)
{
- return !llist_empty(&dev->worker.work_list);
+ struct vhost_worker *worker;
+ bool has_work = false;
+
+ rcu_read_lock();
+ worker = rcu_dereference(vq->worker);
+ if (worker && !llist_empty(&worker->work_list))
+ has_work = true;
+ rcu_read_unlock();
+
+ return has_work;
}
-EXPORT_SYMBOL_GPL(vhost_has_work);
+EXPORT_SYMBOL_GPL(vhost_vq_has_work);
void vhost_poll_queue(struct vhost_poll *poll)
{
- vhost_work_queue(poll->dev, &poll->work);
+ vhost_vq_work_queue(poll->vq, &poll->work);
}
EXPORT_SYMBOL_GPL(vhost_poll_queue);
@@ -329,6 +387,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
vq->busyloop_timeout = 0;
vq->umem = NULL;
vq->iotlb = NULL;
+ rcu_assign_pointer(vq->worker, NULL);
vhost_vring_call_reset(&vq->call_ctx);
__vhost_vq_meta_reset(vq);
}
@@ -458,8 +517,6 @@ void vhost_dev_init(struct vhost_dev *dev,
dev->umem = NULL;
dev->iotlb = NULL;
dev->mm = NULL;
- memset(&dev->worker, 0, sizeof(dev->worker));
- init_llist_head(&dev->worker.work_list);
dev->iov_limit = iov_limit;
dev->weight = weight;
dev->byte_weight = byte_weight;
@@ -469,7 +526,7 @@ void vhost_dev_init(struct vhost_dev *dev,
INIT_LIST_HEAD(&dev->read_list);
INIT_LIST_HEAD(&dev->pending_list);
spin_lock_init(&dev->iotlb_lock);
-
+ xa_init_flags(&dev->worker_xa, XA_FLAGS_ALLOC);
for (i = 0; i < dev->nvqs; ++i) {
vq = dev->vqs[i];
@@ -481,7 +538,7 @@ void vhost_dev_init(struct vhost_dev *dev,
vhost_vq_reset(dev, vq);
if (vq->handle_kick)
vhost_poll_init(&vq->poll, vq->handle_kick,
- EPOLLIN, dev);
+ EPOLLIN, dev, vq);
}
}
EXPORT_SYMBOL_GPL(vhost_dev_init);
@@ -531,38 +588,284 @@ static void vhost_detach_mm(struct vhost_dev *dev)
dev->mm = NULL;
}
-static void vhost_worker_free(struct vhost_dev *dev)
+static void vhost_worker_destroy(struct vhost_dev *dev,
+ struct vhost_worker *worker)
+{
+ if (!worker)
+ return;
+
+ WARN_ON(!llist_empty(&worker->work_list));
+ xa_erase(&dev->worker_xa, worker->id);
+ vhost_task_stop(worker->vtsk);
+ kfree(worker);
+}
+
+static void vhost_workers_free(struct vhost_dev *dev)
{
- if (!dev->worker.vtsk)
+ struct vhost_worker *worker;
+ unsigned long i;
+
+ if (!dev->use_worker)
return;
- WARN_ON(!llist_empty(&dev->worker.work_list));
- vhost_task_stop(dev->worker.vtsk);
- dev->worker.kcov_handle = 0;
- dev->worker.vtsk = NULL;
+ for (i = 0; i < dev->nvqs; i++)
+ rcu_assign_pointer(dev->vqs[i]->worker, NULL);
+ /*
+ * Free the default worker we created and cleanup workers userspace
+ * created but couldn't clean up (it forgot or crashed).
+ */
+ xa_for_each(&dev->worker_xa, i, worker)
+ vhost_worker_destroy(dev, worker);
+ xa_destroy(&dev->worker_xa);
}
-static int vhost_worker_create(struct vhost_dev *dev)
+static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
{
+ struct vhost_worker *worker;
struct vhost_task *vtsk;
char name[TASK_COMM_LEN];
+ int ret;
+ u32 id;
+
+ worker = kzalloc(sizeof(*worker), GFP_KERNEL_ACCOUNT);
+ if (!worker)
+ return NULL;
snprintf(name, sizeof(name), "vhost-%d", current->pid);
- vtsk = vhost_task_create(vhost_worker, &dev->worker, name);
+ vtsk = vhost_task_create(vhost_worker, worker, name);
if (!vtsk)
- return -ENOMEM;
+ goto free_worker;
+
+ mutex_init(&worker->mutex);
+ init_llist_head(&worker->work_list);
+ worker->kcov_handle = kcov_common_handle();
+ worker->vtsk = vtsk;
- dev->worker.kcov_handle = kcov_common_handle();
- dev->worker.vtsk = vtsk;
vhost_task_start(vtsk);
+
+ ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL);
+ if (ret < 0)
+ goto stop_worker;
+ worker->id = id;
+
+ return worker;
+
+stop_worker:
+ vhost_task_stop(vtsk);
+free_worker:
+ kfree(worker);
+ return NULL;
+}
+
+/* Caller must have device mutex */
+static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq,
+ struct vhost_worker *worker)
+{
+ struct vhost_worker *old_worker;
+
+ old_worker = rcu_dereference_check(vq->worker,
+ lockdep_is_held(&vq->dev->mutex));
+
+ mutex_lock(&worker->mutex);
+ worker->attachment_cnt++;
+ mutex_unlock(&worker->mutex);
+ rcu_assign_pointer(vq->worker, worker);
+
+ if (!old_worker)
+ return;
+ /*
+ * Take the worker mutex to make sure we see the work queued from
+ * device wide flushes, which don't use RCU for execution.
+ */
+ mutex_lock(&old_worker->mutex);
+ old_worker->attachment_cnt--;
+ /*
+ * We don't want to call synchronize_rcu for every vq during setup
+ * because it will slow down VM startup. If we haven't done
+ * VHOST_SET_VRING_KICK and not done the driver specific
+ * SET_ENDPOINT/RUNNING then we can skip the sync since there will
+ * not be any works queued for scsi and net.
+ */
+ mutex_lock(&vq->mutex);
+ if (!vhost_vq_get_backend(vq) && !vq->kick) {
+ mutex_unlock(&vq->mutex);
+ mutex_unlock(&old_worker->mutex);
+ /*
+ * vsock can queue anytime after VHOST_VSOCK_SET_GUEST_CID.
+ * Warn if it adds support for multiple workers but forgets to
+ * handle the early queueing case.
+ */
+ WARN_ON(!old_worker->attachment_cnt &&
+ !llist_empty(&old_worker->work_list));
+ return;
+ }
+ mutex_unlock(&vq->mutex);
+
+ /* Make sure new vq queue/flush/poll calls see the new worker */
+ synchronize_rcu();
+ /* Make sure whatever was queued gets run */
+ vhost_worker_flush(old_worker);
+ mutex_unlock(&old_worker->mutex);
+}
+
+ /* Caller must have device mutex */
+static int vhost_vq_attach_worker(struct vhost_virtqueue *vq,
+ struct vhost_vring_worker *info)
+{
+ unsigned long index = info->worker_id;
+ struct vhost_dev *dev = vq->dev;
+ struct vhost_worker *worker;
+
+ if (!dev->use_worker)
+ return -EINVAL;
+
+ worker = xa_find(&dev->worker_xa, &index, UINT_MAX, XA_PRESENT);
+ if (!worker || worker->id != info->worker_id)
+ return -ENODEV;
+
+ __vhost_vq_attach_worker(vq, worker);
+ return 0;
+}
+
+/* Caller must have device mutex */
+static int vhost_new_worker(struct vhost_dev *dev,
+ struct vhost_worker_state *info)
+{
+ struct vhost_worker *worker;
+
+ worker = vhost_worker_create(dev);
+ if (!worker)
+ return -ENOMEM;
+
+ info->worker_id = worker->id;
+ return 0;
+}
+
+/* Caller must have device mutex */
+static int vhost_free_worker(struct vhost_dev *dev,
+ struct vhost_worker_state *info)
+{
+ unsigned long index = info->worker_id;
+ struct vhost_worker *worker;
+
+ worker = xa_find(&dev->worker_xa, &index, UINT_MAX, XA_PRESENT);
+ if (!worker || worker->id != info->worker_id)
+ return -ENODEV;
+
+ mutex_lock(&worker->mutex);
+ if (worker->attachment_cnt) {
+ mutex_unlock(&worker->mutex);
+ return -EBUSY;
+ }
+ mutex_unlock(&worker->mutex);
+
+ vhost_worker_destroy(dev, worker);
return 0;
}
+static int vhost_get_vq_from_user(struct vhost_dev *dev, void __user *argp,
+ struct vhost_virtqueue **vq, u32 *id)
+{
+ u32 __user *idxp = argp;
+ u32 idx;
+ long r;
+
+ r = get_user(idx, idxp);
+ if (r < 0)
+ return r;
+
+ if (idx >= dev->nvqs)
+ return -ENOBUFS;
+
+ idx = array_index_nospec(idx, dev->nvqs);
+
+ *vq = dev->vqs[idx];
+ *id = idx;
+ return 0;
+}
+
+/* Caller must have device mutex */
+long vhost_worker_ioctl(struct vhost_dev *dev, unsigned int ioctl,
+ void __user *argp)
+{
+ struct vhost_vring_worker ring_worker;
+ struct vhost_worker_state state;
+ struct vhost_worker *worker;
+ struct vhost_virtqueue *vq;
+ long ret;
+ u32 idx;
+
+ if (!dev->use_worker)
+ return -EINVAL;
+
+ if (!vhost_dev_has_owner(dev))
+ return -EINVAL;
+
+ ret = vhost_dev_check_owner(dev);
+ if (ret)
+ return ret;
+
+ switch (ioctl) {
+ /* dev worker ioctls */
+ case VHOST_NEW_WORKER:
+ ret = vhost_new_worker(dev, &state);
+ if (!ret && copy_to_user(argp, &state, sizeof(state)))
+ ret = -EFAULT;
+ return ret;
+ case VHOST_FREE_WORKER:
+ if (copy_from_user(&state, argp, sizeof(state)))
+ return -EFAULT;
+ return vhost_free_worker(dev, &state);
+ /* vring worker ioctls */
+ case VHOST_ATTACH_VRING_WORKER:
+ case VHOST_GET_VRING_WORKER:
+ break;
+ default:
+ return -ENOIOCTLCMD;
+ }
+
+ ret = vhost_get_vq_from_user(dev, argp, &vq, &idx);
+ if (ret)
+ return ret;
+
+ switch (ioctl) {
+ case VHOST_ATTACH_VRING_WORKER:
+ if (copy_from_user(&ring_worker, argp, sizeof(ring_worker))) {
+ ret = -EFAULT;
+ break;
+ }
+
+ ret = vhost_vq_attach_worker(vq, &ring_worker);
+ break;
+ case VHOST_GET_VRING_WORKER:
+ worker = rcu_dereference_check(vq->worker,
+ lockdep_is_held(&dev->mutex));
+ if (!worker) {
+ ret = -EINVAL;
+ break;
+ }
+
+ ring_worker.index = idx;
+ ring_worker.worker_id = worker->id;
+
+ if (copy_to_user(argp, &ring_worker, sizeof(ring_worker)))
+ ret = -EFAULT;
+ break;
+ default:
+ ret = -ENOIOCTLCMD;
+ break;
+ }
+
+ return ret;
+}
+EXPORT_SYMBOL_GPL(vhost_worker_ioctl);
+
/* Caller should have device mutex */
long vhost_dev_set_owner(struct vhost_dev *dev)
{
- int err;
+ struct vhost_worker *worker;
+ int err, i;
/* Is there an owner already? */
if (vhost_dev_has_owner(dev)) {
@@ -572,20 +875,32 @@ long vhost_dev_set_owner(struct vhost_dev *dev)
vhost_attach_mm(dev);
- if (dev->use_worker) {
- err = vhost_worker_create(dev);
- if (err)
- goto err_worker;
- }
-
err = vhost_dev_alloc_iovecs(dev);
if (err)
goto err_iovecs;
+ if (dev->use_worker) {
+ /*
+ * This should be done last, because vsock can queue work
+ * before VHOST_SET_OWNER so it simplifies the failure path
+ * below since we don't have to worry about vsock queueing
+ * while we free the worker.
+ */
+ worker = vhost_worker_create(dev);
+ if (!worker) {
+ err = -ENOMEM;
+ goto err_worker;
+ }
+
+ for (i = 0; i < dev->nvqs; i++)
+ __vhost_vq_attach_worker(dev->vqs[i], worker);
+ }
+
return 0;
-err_iovecs:
- vhost_worker_free(dev);
+
err_worker:
+ vhost_dev_free_iovecs(dev);
+err_iovecs:
vhost_detach_mm(dev);
err_mm:
return err;
@@ -677,7 +992,7 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
dev->iotlb = NULL;
vhost_clear_msg(dev);
wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM);
- vhost_worker_free(dev);
+ vhost_workers_free(dev);
vhost_detach_mm(dev);
}
EXPORT_SYMBOL_GPL(vhost_dev_cleanup);
@@ -1565,21 +1880,15 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg
struct file *eventfp, *filep = NULL;
bool pollstart = false, pollstop = false;
struct eventfd_ctx *ctx = NULL;
- u32 __user *idxp = argp;
struct vhost_virtqueue *vq;
struct vhost_vring_state s;
struct vhost_vring_file f;
u32 idx;
long r;
- r = get_user(idx, idxp);
+ r = vhost_get_vq_from_user(d, argp, &vq, &idx);
if (r < 0)
return r;
- if (idx >= d->nvqs)
- return -ENOBUFS;
-
- idx = array_index_nospec(idx, d->nvqs);
- vq = d->vqs[idx];
if (ioctl == VHOST_SET_VRING_NUM ||
ioctl == VHOST_SET_VRING_ADDR) {
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index fc900be504b38e..f60d5f7bef944e 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -28,8 +28,12 @@ struct vhost_work {
struct vhost_worker {
struct vhost_task *vtsk;
+ /* Used to serialize device wide flushing with worker swapping. */
+ struct mutex mutex;
struct llist_head work_list;
u64 kcov_handle;
+ u32 id;
+ int attachment_cnt;
};
/* Poll a file (eventfd or socket) */
@@ -41,17 +45,17 @@ struct vhost_poll {
struct vhost_work work;
__poll_t mask;
struct vhost_dev *dev;
+ struct vhost_virtqueue *vq;
};
-void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn);
-void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work);
-bool vhost_has_work(struct vhost_dev *dev);
-
void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn,
- __poll_t mask, struct vhost_dev *dev);
+ __poll_t mask, struct vhost_dev *dev,
+ struct vhost_virtqueue *vq);
int vhost_poll_start(struct vhost_poll *poll, struct file *file);
void vhost_poll_stop(struct vhost_poll *poll);
void vhost_poll_queue(struct vhost_poll *poll);
+
+void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn);
void vhost_dev_flush(struct vhost_dev *dev);
struct vhost_log {
@@ -74,6 +78,7 @@ struct vhost_vring_call {
/* The virtqueue structure describes a queue attached to a device. */
struct vhost_virtqueue {
struct vhost_dev *dev;
+ struct vhost_worker __rcu *worker;
/* The actual ring of buffers. */
struct mutex mutex;
@@ -158,7 +163,6 @@ struct vhost_dev {
struct vhost_virtqueue **vqs;
int nvqs;
struct eventfd_ctx *log_ctx;
- struct vhost_worker worker;
struct vhost_iotlb *umem;
struct vhost_iotlb *iotlb;
spinlock_t iotlb_lock;
@@ -168,6 +172,7 @@ struct vhost_dev {
int iov_limit;
int weight;
int byte_weight;
+ struct xarray worker_xa;
bool use_worker;
int (*msg_handler)(struct vhost_dev *dev, u32 asid,
struct vhost_iotlb_msg *msg);
@@ -188,16 +193,21 @@ void vhost_dev_cleanup(struct vhost_dev *);
void vhost_dev_stop(struct vhost_dev *);
long vhost_dev_ioctl(struct vhost_dev *, unsigned int ioctl, void __user *argp);
long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp);
+long vhost_worker_ioctl(struct vhost_dev *dev, unsigned int ioctl,
+ void __user *argp);
bool vhost_vq_access_ok(struct vhost_virtqueue *vq);
bool vhost_log_access_ok(struct vhost_dev *);
void vhost_clear_msg(struct vhost_dev *dev);
int vhost_get_vq_desc(struct vhost_virtqueue *,
- struct iovec iov[], unsigned int iov_count,
+ struct iovec iov[], unsigned int iov_size,
unsigned int *out_num, unsigned int *in_num,
struct vhost_log *log, unsigned int *log_num);
void vhost_discard_vq_desc(struct vhost_virtqueue *, int n);
+void vhost_vq_flush(struct vhost_virtqueue *vq);
+bool vhost_vq_work_queue(struct vhost_virtqueue *vq, struct vhost_work *work);
+bool vhost_vq_has_work(struct vhost_virtqueue *vq);
bool vhost_vq_is_setup(struct vhost_virtqueue *vq);
int vhost_vq_init_access(struct vhost_virtqueue *);
int vhost_add_used(struct vhost_virtqueue *, unsigned int head, int len);
diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index 6578db78f0ae27..817d377a3f360f 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -285,7 +285,7 @@ vhost_transport_send_pkt(struct sk_buff *skb)
atomic_inc(&vsock->queued_replies);
virtio_vsock_skb_queue_tail(&vsock->send_pkt_queue, skb);
- vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);
+ vhost_vq_work_queue(&vsock->vqs[VSOCK_VQ_RX], &vsock->send_pkt_work);
rcu_read_unlock();
return len;
@@ -583,7 +583,7 @@ static int vhost_vsock_start(struct vhost_vsock *vsock)
/* Some packets may have been queued before the device was started,
* let's kick the send worker to send them.
*/
- vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);
+ vhost_vq_work_queue(&vsock->vqs[VSOCK_VQ_RX], &vsock->send_pkt_work);
mutex_unlock(&vsock->dev.mutex);
return 0;