ThreadLocal精进篇:子线程类InheritableThreadLocal
背景
ThreadLocal可以保证在当前运行线程中的变量不被其他并发下的线程共享。
但是如果在代码中需要使用多线程呢?
此时ThreadLocal又该如何保证数据在相关子线程中的传递和安全性呢?
InheritableThreadLocal给我们提供了一丝可能。
InheritableThreadLocal
InheritableThreadLocal源码
先简单瞄下源码
package java.lang;

import java.lang.ref.*;

/**
 * A {@link ThreadLocal} subclass that keeps its per-thread map in
 * {@code Thread.inheritableThreadLocals} rather than in the field used by
 * plain thread-locals, and supplies a {@link #childValue} implementation.
 *
 * @param <T> the type of the value held by this variable
 */
public class InheritableThreadLocal<T> extends ThreadLocal<T> {

    /**
     * Computes the value a child thread starts with, given the parent's value.
     * This implementation returns the parent's value unchanged.
     *
     * @param parentValue the parent thread's value
     * @return the child thread's initial value
     */
    protected T childValue(T parentValue) {
        return parentValue;
    }

    /**
     * Returns the map that holds inheritable thread-local values for a thread.
     *
     * @param t the current thread
     * @return the thread's {@code inheritableThreadLocals} map (may be null)
     */
    ThreadLocalMap getMap(Thread t) {
        return t.inheritableThreadLocals;
    }

    /**
     * Creates the thread's {@code inheritableThreadLocals} map, seeded with
     * one entry for this variable.
     *
     * @param t the current thread
     * @param firstValue value for the initial entry of the map
     */
    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}
可以很清楚的看到,InheritableThreadLocal是ThreadLocal的子类。
测试代码
一切以代码为主,我们先上一段代码
package com.chuangyue;

import com.google.common.util.concurrent.ThreadFactoryBuilder;

import java.util.concurrent.*;

/**
 * Demo for inheritable thread-locals: sets a ThreadLocal and an
 * InheritableThreadLocal on the calling thread, then reads both back from a
 * thread created by a thread pool.
 */
public class InheritableThreadLocalTest {

    public static void main(String[] args) {
        ThreadLocal<String> threadLocal = new ThreadLocal<>();
        InheritableThreadLocal<String> inheritableThreadLocal = new InheritableThreadLocal<>();
        threadLocal.set(" I'm threadLocal ");
        inheritableThreadLocal.set("I'm inheritableThreadLocal");
        test01(threadLocal, inheritableThreadLocal);
    }

    /**
     * Prints, from a single pool thread, the values that thread sees for the
     * two variables.
     *
     * @param threadLocal            variable set by the caller via plain ThreadLocal
     * @param inheritableThreadLocal variable set by the caller via InheritableThreadLocal
     */
    public static void test01(ThreadLocal<?> threadLocal,
                              InheritableThreadLocal<?> inheritableThreadLocal) {
        ThreadFactory namedThreadFactory = new ThreadFactoryBuilder()
                .setNameFormat("demo-pool-%d").build();
        ExecutorService singleThreadPool = new ThreadPoolExecutor(1, 1,
                0L, TimeUnit.MILLISECONDS,
                new LinkedBlockingQueue<>(1024),
                namedThreadFactory,
                new ThreadPoolExecutor.AbortPolicy());
        try {
            singleThreadPool.execute(() -> {
                System.out.println("========= begin =========");
                System.out.println(Thread.currentThread().getName() + " ThreadLocal: " + threadLocal.get());
                System.out.println(Thread.currentThread().getName() + " InheritableThreadLocal: " + inheritableThreadLocal.get());
                System.out.println("========= end =========");
            });
        } finally {
            // Always release the pool, even if the task is rejected.
            singleThreadPool.shutdown();
        }
    }
}
这段测试代码里面,只是设置了ThreadLocal和InheritableThreadLocal两个对象,并且对其进行赋值。
然后再在线程中打印对应的对象结果。
输出结果如下
按照正常的理解,应该是两个值都要输出。
但是在这里,我们可以清楚的看到,ThreadLocal输出值为null
这是为什么呢?
还是老规矩,我们看下源码
ThreadLocal源码
package java.lang;import jdk.internal.misc.TerminatingThreadLocal;import java.lang.ref.*;import java.util.Objects;import java.util.concurrent.atomic.AtomicInteger;import java.util.function.Supplier;public class ThreadLocal {private final int threadLocalHashCode = nextHashCode(); /** * The next hash code to be given out. Updated atomically. Starts at * zero. */ private static AtomicInteger nextHashCode = new AtomicInteger(); private static final int HASH_INCREMENT = 0x61c88647; /** * Returns the next hash code. */ private static int nextHashCode() { return nextHashCode.getAndAdd(HASH_INCREMENT); } protected T initialValue() { return null; } public static ThreadLocal withInitial(Supplier supplier) { return new SuppliedThreadLocal(supplier); } /** * Creates a thread local variable. * @see #withInitial(java.util.function.Supplier) */ public ThreadLocal() { } public T get() { Thread t = Thread.currentThread(); ThreadLocalMap map = getMap(t); if (map != null) { ThreadLocalMap.Entry e = map.getEntry(this); if (e != null) { @SuppressWarnings("unchecked") T result = (T)e.value; return result; } } return setInitialValue(); }boolean isPresent() { Thread t = Thread.currentThread(); ThreadLocalMap map = getMap(t); return map != null && map.getEntry(this) != null; } private T setInitialValue() { T value = initialValue(); Thread t = Thread.currentThread(); ThreadLocalMap map = getMap(t); if (map != null) { map.set(this, value); } else { createMap(t, value); } if (this instanceof TerminatingThreadLocal) { TerminatingThreadLocal.register((TerminatingThreadLocal) this); } return value; }public void set(T value) { Thread t = Thread.currentThread(); ThreadLocalMap map = getMap(t); if (map != null) { map.set(this, value); } else { createMap(t, value); } } public void remove() { ThreadLocalMap m = getMap(Thread.currentThread()); if (m != null) { m.remove(this); } } /** * Get the map associated with a ThreadLocal. Overridden in * InheritableThreadLocal. 
* * @param t the current thread * @return the map */ ThreadLocalMap getMap(Thread t) { return t.threadLocals; } /** * Create the map associated with a ThreadLocal. Overridden in * InheritableThreadLocal. * * @param t the current thread * @param firstValue value for the initial entry of the map */ void createMap(Thread t, T firstValue) { t.threadLocals = new ThreadLocalMap(this, firstValue); } /** * Factory method to create map of inherited thread locals. * Designed to be called only from Thread constructor. * * @param parentMap the map associated with parent thread * @return a map containing the parent's inheritable bindings */ static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) { return new ThreadLocalMap(parentMap); } /** * Method childValue is visibly defined in subclass * InheritableThreadLocal, but is internally defined here for the * sake of providing createInheritedMap factory method without * needing to subclass the map class in InheritableThreadLocal. * This technique is preferable to the alternative of embedding * instanceof tests in methods. */ T childValue(T parentValue) { throw new UnsupportedOperationException(); } /** * An extension of ThreadLocal that obtains its initial value from * the specified {@code Supplier}. */ static final class SuppliedThreadLocal extends ThreadLocal { private final Supplier supplier; SuppliedThreadLocal(Supplier supplier) { this.supplier = Objects.requireNonNull(supplier); } @Override protected T initialValue() { return supplier.get(); } } /** * ThreadLocalMap is a customized hash map suitable only for * maintaining thread local values. No operations are exported * outside of the ThreadLocal class. The class is package private to * allow declaration of fields in class Thread. To help deal with * very large and long-lived usages, the hash table entries use * WeakReferences for keys. 
However, since reference queues are not * used, stale entries are guaranteed to be removed only when * the table starts running out of space. */ static class ThreadLocalMap { /** * The entries in this hash map extend WeakReference, using * its main ref field as the key (which is always a * ThreadLocal object). Note that null keys (i.e. entry.get() * == null) mean that the key is no longer referenced, so the * entry can be expunged from table. Such entries are referred to * as "stale entries" in the code that follows. */ static class Entry extends WeakReference<ThreadLocal> { /** The value associated with this ThreadLocal. */ Object value; Entry(ThreadLocal k, Object v) { super(k); value = v; } } /** * The initial capacity -- MUST be a power of two. */ private static final int INITIAL_CAPACITY = 16; /** * The table, resized as necessary. * table.length MUST always be a power of two. */ private Entry[] table; /** * The number of entries in the table. */ private int size = 0; /** * The next size value at which to resize. */ private int threshold; // Default to 0 /** * Set the resize threshold to maintain at worst a 2/3 load factor. */ private void setThreshold(int len) { threshold = len * 2 / 3; } /** * Increment i modulo len. */ private static int nextIndex(int i, int len) { return ((i + 1 = 0) ? i - 1 : len - 1); } /** * Construct a new map initially containing (firstKey, firstValue). * ThreadLocalMaps are constructed lazily, so we only create * one when we have at least one entry to put in it. */ ThreadLocalMap(ThreadLocal firstKey, Object firstValue) { table = new Entry[INITIAL_CAPACITY]; int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1); table[i] = new Entry(firstKey, firstValue); size = 1; setThreshold(INITIAL_CAPACITY); } /** * Construct a new map including all Inheritable ThreadLocals * from given parent map. Called only by createInheritedMap. * * @param parentMap the map associated with parent thread. 
*/ private ThreadLocalMap(ThreadLocalMap parentMap) { Entry[] parentTable = parentMap.table; int len = parentTable.length; setThreshold(len); table = new Entry[len]; for (Entry e : parentTable) { if (e != null) { @SuppressWarnings("unchecked") ThreadLocal key = (ThreadLocal) e.get(); if (key != null) { Object value = key.childValue(e.value); Entry c = new Entry(key, value); int h = key.threadLocalHashCode & (len - 1); while (table[h] != null)h = nextIndex(h, len); table[h] = c; size++; } } } } /** * Get the entry associated with key. This method * itself handles only the fast path: a direct hit of existing * key. It otherwise relays to getEntryAfterMiss. This is * designed to maximize performance for direct hits, in part * by making this method readily inlinable. * * @param key the thread local object * @return the entry associated with key, or null if no such */ private Entry getEntry(ThreadLocal key) { int i = key.threadLocalHashCode & (table.length - 1); Entry e = table[i]; if (e != null && e.get() == key) return e; else return getEntryAfterMiss(key, i, e); } /** * Version of getEntry method for use when key is not found in * its direct hash slot. * * @param key the thread local object * @param i the table index for key's hash code * @param e the entry at table[i] * @return the entry associated with key, or null if no such */ private Entry getEntryAfterMiss(ThreadLocal key, int i, Entry e) { Entry[] tab = table; int len = tab.length; while (e != null) { ThreadLocal k = e.get(); if (k == key) return e; if (k == null) expungeStaleEntry(i); else i = nextIndex(i, len); e = tab[i]; } return null; } /** * Set the value associated with key. 
* * @param key the thread local object * @param value the value to be set */ private void set(ThreadLocal key, Object value) { // We don't use a fast path as with get() because it is at // least as common to use set() to create new entries as // it is to replace existing ones, in which case, a fast // path would fail more often than not. Entry[] tab = table; int len = tab.length; int i = key.threadLocalHashCode & (len-1); for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) { ThreadLocal k = e.get(); if (k == key) { e.value = value; return; } if (k == null) { replaceStaleEntry(key, value, i); return; } } tab[i] = new Entry(key, value); int sz = ++size; if (!cleanSomeSlots(i, sz) && sz >= threshold) rehash(); } /** * Remove the entry for key. */ private void remove(ThreadLocal key) { Entry[] tab = table; int len = tab.length; int i = key.threadLocalHashCode & (len-1); for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) { if (e.get() == key) { e.clear(); expungeStaleEntry(i); return; } } } /** * Replace a stale entry encountered during a set operation * with an entry for the specified key. The value passed in * the value parameter is stored in the entry, whether or not * an entry already exists for the specified key. * * As a side effect, this method expunges all stale entries in the * "run" containing the stale entry. (A run is a sequence of entries * between two null slots.) * * @param key the key * @param value the value to be associated with key * @param staleSlot index of the first stale entry encountered while * searching for key. */ private void replaceStaleEntry(ThreadLocal key, Object value, int staleSlot) { Entry[] tab = table; int len = tab.length; Entry e; // Back up to check for prior stale entry in current run. // We clean out whole runs at a time to avoid continual // incremental rehashing due to garbage collector freeing // up refs in bunches (i.e., whenever the collector runs). 
int slotToExpunge = staleSlot; for (int i = prevIndex(staleSlot, len); (e = tab[i]) != null; i = prevIndex(i, len)) if (e.get() == null) slotToExpunge = i; // Find either the key or trailing null slot of run, whichever // occurs first for (int i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) { ThreadLocal k = e.get(); // If we find key, then we need to swap it // with the stale entry to maintain hash table order. // The newly stale slot, or any other stale slot // encountered above it, can then be sent to expungeStaleEntry // to remove or rehash all of the other entries in run. if (k == key) { e.value = value; tab[i] = tab[staleSlot]; tab[staleSlot] = e; // Start expunge at preceding stale entry if it exists if (slotToExpunge == staleSlot) slotToExpunge = i; cleanSomeSlots(expungeStaleEntry(slotToExpunge), len); return; } // If we didn't find stale entry on backward scan, the // first stale entry seen while scanning for key is the // first still present in the run. if (k == null && slotToExpunge == staleSlot) slotToExpunge = i; } // If key not found, put new entry in stale slot tab[staleSlot].value = null; tab[staleSlot] = new Entry(key, value); // If there are any other stale entries in run, expunge them if (slotToExpunge != staleSlot) cleanSomeSlots(expungeStaleEntry(slotToExpunge), len); } /** * Expunge a stale entry by rehashing any possibly colliding entries * lying between staleSlot and the next null slot. This also expunges * any other stale entries encountered before the trailing null. See * Knuth, Section 6.4 * * @param staleSlot index of slot known to have null key * @return the index of the next null slot after staleSlot * (all between staleSlot and this slot will have been checked * for expunging). 
*/ private int expungeStaleEntry(int staleSlot) { Entry[] tab = table; int len = tab.length; // expunge entry at staleSlot tab[staleSlot].value = null; tab[staleSlot] = null; size--; // Rehash until we encounter null Entry e; int i; for (i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) { ThreadLocal k = e.get(); if (k == null) { e.value = null; tab[i] = null; size--; } else { int h = k.threadLocalHashCode & (len - 1); if (h != i) { tab[i] = null; // Unlike Knuth 6.4 Algorithm R, we must scan until // null because multiple entries could have been stale. while (tab[h] != null)h = nextIndex(h, len); tab[h] = e; } } } return i; } /** * Heuristically scan some cells looking for stale entries. * This is invoked when either a new element is added, or * another stale one has been expunged. It performs a * logarithmic number of scans, as a balance between no * scanning (fast but retains garbage) and a number of scans * proportional to number of elements, that would find all * garbage but would cause some insertions to take O(n) time. * * @param i a position known NOT to hold a stale entry. The * scan starts at the element after i. * * @param n scan control: {@code log2(n)} cells are scanned, * unless a stale entry is found, in which case * {@code log2(table.length)-1} additional cells are scanned. * When called from insertions, this parameter is the number * of elements, but when from replaceStaleEntry, it is the * table length. (Note: all this could be changed to be either * more or less aggressive by weighting n instead of just * using straight log n. But this version is simple, fast, and * seems to work well.) * * @return true if any stale entries have been removed. 
*/ private boolean cleanSomeSlots(int i, int n) { boolean removed = false; Entry[] tab = table; int len = tab.length; do { i = nextIndex(i, len); Entry e = tab[i]; if (e != null && e.get() == null) { n = len; removed = true; i = expungeStaleEntry(i); } } while ( (n >>>= 1) != 0); return removed; } /** * Re-pack and/or re-size the table. First scan the entire * table removing stale entries. If this doesn't sufficiently * shrink the size of the table, double the table size. */ private void rehash() { expungeStaleEntries(); // Use lower threshold for doubling to avoid hysteresis if (size >= threshold - threshold / 4) resize(); } /** * Double the capacity of the table. */ private void resize() { Entry[] oldTab = table; int oldLen = oldTab.length; int newLen = oldLen * 2; Entry[] newTab = new Entry[newLen]; int count = 0; for (Entry e : oldTab) { if (e != null) { ThreadLocal k = e.get(); if (k == null) { e.value = null; // Help the GC } else { int h = k.threadLocalHashCode & (newLen - 1); while (newTab[h] != null)h = nextIndex(h, newLen); newTab[h] = e; count++; } } } setThreshold(newLen); size = count; table = newTab; } /** * Expunge all stale entries in the table. */ private void expungeStaleEntries() { Entry[] tab = table; int len = tab.length; for (int j = 0; j < len; j++) { Entry e = tab[j]; if (e != null && e.get() == null) expungeStaleEntry(j); } } }}
我们可以很清楚的看到,在ThreadLocal中也存在着以下三个方法:getMap,createMap和childValue。
/**
 * Returns the given thread's {@code threadLocals} map, the one used for
 * plain ThreadLocal values.
 */
ThreadLocalMap getMap(Thread t) { return t.threadLocals; }
/**
 * Assigns the given thread a new {@code threadLocals} map seeded with one
 * entry: this variable mapped to {@code firstValue}.
 */
void createMap(Thread t, T firstValue) { t.threadLocals = new ThreadLocalMap(this, firstValue); }
/**
 * Base-class version of childValue: always throws
 * {@code UnsupportedOperationException}.
 */
 T childValue(T parentValue) { throw new UnsupportedOperationException(); }
我们看下InheritableThreadLocal源码中,对重写的这三个方法的代码注解
/**
 * Computes the child's initial value for this inheritable thread-local
 * variable as a function of the parent's value at the time the child
 * thread is created. This method is called from within the parent
 * thread before the child is started.
 *
 * This method merely returns its input argument, and should be overridden
 * if a different behavior is desired.
 *
 * @param parentValue the parent thread's value
 * @return the child thread's initial value
 */
protected T childValue(T parentValue) { return parentValue; }
/**
 * Get the map associated with a ThreadLocal; here that is the thread's
 * {@code inheritableThreadLocals} field.
 *
 * @param t the current thread
 */
ThreadLocalMap getMap(Thread t) {return t.inheritableThreadLocals; }
/**
 * Create the map associated with a ThreadLocal and store it in the
 * thread's {@code inheritableThreadLocals} field.
 *
 * @param t the current thread
 * @param firstValue value for the initial entry of the table.
 */
void createMap(Thread t, T firstValue) { t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue); }
我们直接跟踪下代码
这个是执行子线程的时候,发现当前线程的threadLocals直接为空了。
为什么ThreadLocal值为空
这个问题主要出在Thread上。
我们在Thread代码中找到了这两行
public class Thread implements Runnable { ... ThreadLocal.ThreadLocalMap threadLocals = null; ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;...
每个线程都会有一个自己的ThreadLocalMap,所以子线程在调用get方法拿值的时候其实访问的是自己的ThreadLocalMap,这个Map和主线程的Map是两个不同的对象,所以肯定是拿不到值的。子线程的threadLocals初始值是null,所以返回就是null了。这个没毛病
那么,inheritableThreadLocals初始化也为null,为什么它就有值呢?
Thread初探
Thread自己怎么玩的
在研究这个之前,我们先看下Thread自己是怎么玩的(初始化)
//构造函数 public Thread(ThreadGroup group, String name) { this(group, null, name, 0); } public Thread(Runnable target, String name) { this(null, target, name, 0); } public Thread(ThreadGroup group, Runnable target, String name, long stackSize) { this(group, target, name, stackSize, null, true); }...
最后它们都甩锅给了它:private Thread
/**
 * Initializes a new Thread from the thread that is constructing it.
 *
 * @param g the thread group; when null it is taken from the security
 *          manager, or failing that from the constructing thread
 * @param target the object stored in this thread's {@code target} field
 * @param name the name of the new thread; must not be null
 * @param stackSize the stack size stored for the new thread
 * @param acc the AccessControlContext to inherit, or
 *            {@code AccessController.getContext()} if null
 * @param inheritThreadLocals if {@code true}, copy the constructing
 *            thread's inheritable thread-locals into the new thread
 * @throws NullPointerException if {@code name} is null
 */
private Thread(ThreadGroup g, Runnable target, String name,
               long stackSize, AccessControlContext acc,
               boolean inheritThreadLocals) {
    if (name == null) {
        throw new NullPointerException("name cannot be null");
    }

    this.name = name;

    // The thread running this constructor is the parent of the new thread.
    Thread parent = currentThread();
    SecurityManager security = System.getSecurityManager();
    if (g == null) {
        /* Determine if it's an applet or not */

        /* If there is a security manager, ask the security
           manager what to do. */
        if (security != null) {
            g = security.getThreadGroup();
        }

        /* If the security manager doesn't have a strong opinion
           on the matter, use the parent thread group. */
        if (g == null) {
            g = parent.getThreadGroup();
        }
    }

    /* checkAccess regardless of whether or not threadgroup is
       explicitly passed in. */
    g.checkAccess();

    /*
     * Do we have the required permissions?
     */
    if (security != null) {
        if (isCCLOverridden(getClass())) {
            security.checkPermission(
                    SecurityConstants.SUBCLASS_IMPLEMENTATION_PERMISSION);
        }
    }

    g.addUnstarted();

    this.group = g;
    this.daemon = parent.isDaemon();
    this.priority = parent.getPriority();
    if (security == null || isCCLOverridden(parent.getClass()))
        this.contextClassLoader = parent.getContextClassLoader();
    else
        this.contextClassLoader = parent.contextClassLoader;
    this.inheritedAccessControlContext =
            acc != null ? acc : AccessController.getContext();
    this.target = target;
    setPriority(priority);
    // Only the parent's inheritableThreadLocals map is copied here; the
    // parent's threadLocals map is not touched by this constructor.
    if (inheritThreadLocals && parent.inheritableThreadLocals != null)
        this.inheritableThreadLocals =
            ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
    /* Stash the specified stack size in case the VM cares */
    this.stackSize = stackSize;

    /* Set thread ID */
    this.tid = nextThreadID();
}
。。。
我们可以看到,默认情况下,inheritThreadLocals的值是true。也就是设置inheritableThreadLocal默认是可传递的。
ThreadLocal.createInheritedMap
我们可以看到上面的代码中有这么一行代码,用来创建inheritedMap。
我们继续剖析createInheritedMap源码
在createInheritedMap中,将所有的父线程中的Map的值,使用for的方式全部复制到子线程中。
/**
 * Builds a map for a child thread from the parent's map by delegating to
 * the private copying constructor below.
 *
 * @param parentMap the map associated with the parent thread
 * @return a new map holding the parent's live entries
 */
static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
    return new ThreadLocalMap(parentMap);
}

/**
 * Construct a new ThreadLocalMap including all Inheritable ThreadLocals
 * from the given parent map. Called only by createInheritedMap.
 *
 * @param parentMap the map associated with the parent thread
 */
private ThreadLocalMap(ThreadLocalMap parentMap) {
    Entry[] parentTable = parentMap.table;
    int len = parentTable.length;
    setThreshold(len);
    table = new Entry[len];

    for (Entry e : parentTable) {
        if (e != null) {
            @SuppressWarnings("unchecked")
            ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
            // A null key means the ThreadLocal is no longer referenced; skip it.
            if (key != null) {
                // The child's value is whatever the key's childValue() returns.
                Object value = key.childValue(e.value);
                Entry c = new Entry(key, value);
                int h = key.threadLocalHashCode & (len - 1);
                // Probe forward to the next free slot on collision.
                while (table[h] != null)
                    h = nextIndex(h, len);
                table[h] = c;
                size++;
            }
        }
    }
}
总结
所以,Thread类中包含 threadLocals 和 inheritableThreadLocals 两个变量,其中 inheritableThreadLocals 是在创建子线程时会自动复制给子线程的ThreadLocal.ThreadLocalMap。这样,在子线程中进行get()操作的时候,自然就能输出值了。